torch.func.grad(func, argnums=0, has_aux=False)

The grad operator computes gradients of func with respect to the input(s) specified by argnums. It can be nested to compute higher-order gradients.

Parameters

  • func (Callable) – A Python function that takes one or more arguments. It must return a single-element Tensor. If has_aux is True, the function may instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.

  • has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.


Returns

A function that computes gradients with respect to its inputs. By default, that function returns the gradient tensor(s) with respect to the first argument. If has_aux is True, it returns a tuple of the gradients and the auxiliary objects. If argnums is a tuple of integers, it returns a tuple of gradients, one for each argnums value.

Return type

Callable

Example of using grad:

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
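The example above uses a 0-dim input. As a minimal sketch, the same call also works for a vector input, provided func still reduces to a single-element Tensor. The `sum()` reduction here is an illustrative choice, not something the reference prescribes.

```python
import torch
from torch.func import grad

x = torch.randn(3)

# func must return a single-element tensor, so reduce the vector with sum().
grad_x = grad(lambda t: torch.sin(t).sum())(x)

# The gradient has the same shape as the input and equals cos(x) elementwise.
assert grad_x.shape == x.shape
assert torch.allclose(grad_x, x.cos())
```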

When composed with vmap, grad can be used to compute per-sample-gradients:

>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
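One way to sanity-check the per-sample gradients is to compare them against an explicit loop over the batch. The sketch below assumes the model body is a dot product followed by a ReLU, which matches the "linear model with activation" comment but is an assumption here. The names `per_sample` and `looped` are illustrative.

```python
import torch
from torch.func import grad, vmap

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Assumed model body: linear layer (dot product) followed by ReLU.
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized: weights are shared (None), examples and targets are batched on dim 0.
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: call grad once per example and stack the results.
looped = torch.stack([
    grad(compute_loss)(weights, examples[i], targets[i])
    for i in range(batch_size)
])

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped, atol=1e-5)
```

The result has one gradient row per example, so its shape is (batch_size, feature_size).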

Example of using grad with has_aux and argnums:

>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>     loss_per_sample = (0.5 * y_pred - y) ** 2
>>>     loss = loss_per_sample.mean()
>>>     return loss, (y_pred, loss_per_sample)
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
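To make the nested return structure concrete, the following sketch unpacks the output and checks each gradient against the value derived by hand from the loss definition. The variable names on the left-hand side of the unpacking are illustrative.

```python
import torch
from torch.func import grad

def my_loss_func(y, y_pred):
    loss_per_sample = (0.5 * y_pred - y) ** 2
    loss = loss_per_sample.mean()
    return loss, (y_pred, loss_per_sample)

fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
y_true = torch.rand(4)
y_preds = torch.rand(4)

# Outer tuple: (gradients, aux). Inner tuples follow argnums and the aux layout.
(grad_y, grad_y_pred), (aux_pred, aux_loss) = fn(y_true, y_preds)

# Hand-derived gradients of mean((0.5 * p - y) ** 2) over n samples:
#   d/dy = -2 * (0.5 * p - y) / n
#   d/dp =      (0.5 * p - y) / n
n = y_true.numel()
residual = 0.5 * y_preds - y_true
assert torch.allclose(grad_y, -2 * residual / n)
assert torch.allclose(grad_y_pred, residual / n)
assert torch.allclose(aux_pred, y_preds)
assert torch.allclose(aux_loss, residual ** 2)
```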


Using PyTorch torch.no_grad together with grad.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, grad(f)(x) will respect the inner torch.no_grad.
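As a small sketch of what "respect" means here: because c is computed under torch.no_grad, it is treated as a constant, so the gradient of x - c is 1 rather than 1 - 2x. The comparison function `f_tracked` is a hypothetical helper added only for contrast.

```python
import torch
from torch.func import grad

def f(x):
    with torch.no_grad():
        c = x ** 2  # not tracked: treated as a constant by grad
    return x - c

def f_tracked(x):
    # Hypothetical variant without no_grad, for comparison only.
    c = x ** 2
    return x - c

x = torch.randn([])
g_inner = grad(f)(x)
g_tracked = grad(f_tracked)(x)

assert torch.allclose(g_inner, torch.ones_like(x))
assert torch.allclose(g_tracked, 1 - 2 * x)
```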

Case 2: Using grad inside torch.no_grad context manager:

>>> with torch.no_grad():
>>>     grad(f)(x)

In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a “function transform”: its result should not depend on a context manager outside of f.
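The following sketch illustrates Case 2 under the behavior described above: gradients are still computed inside an outer torch.no_grad block, while the no_grad inside f keeps treating c as a constant. The cubic lambda is an illustrative extra check.

```python
import torch
from torch.func import grad

def f(x):
    with torch.no_grad():
        c = x ** 2  # inner no_grad: c is a constant for grad
    return x - c

x = torch.randn([])

with torch.no_grad():
    # The outer no_grad does not disable the transform.
    inside = grad(f)(x)
    cube_grad = grad(lambda t: t ** 3)(x)

# Inner no_grad respected: d/dx (x - const) = 1.
assert torch.allclose(inside, torch.ones_like(x))
# Outer no_grad ignored: d/dx x^3 = 3 x^2.
assert torch.allclose(cube_grad, 3 * x ** 2)
```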

