torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)

Returns a function that computes a tuple of the gradient and the primal (forward) computation, that is, the value returned by func.

Parameters
  • func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux is True, func may instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.

  • has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
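
As a minimal sketch of how argnums selects the differentiated arguments, the snippet below wraps a small scalar loss. The function squared_error and the tensors w, x, and y are hypothetical names chosen for illustration; they do not come from the PyTorch documentation.

```python
import torch
from torch.func import grad_and_value

def squared_error(w, x, y):
    # Hypothetical scalar loss: sum of squared residuals of a linear model.
    return ((x @ w - y) ** 2).sum()

w = torch.tensor([1.0, 2.0])
x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
y = torch.tensor([0.0, 0.0])

# Default argnums=0: the gradient is taken with respect to w only.
grad_w, loss = grad_and_value(squared_error)(w, x, y)

# argnums=(0, 2): the gradients come back as a tuple, one per selected argument.
(grad_w2, grad_y), loss2 = grad_and_value(squared_error, argnums=(0, 2))(w, x, y)
```

In both calls the second element of the result is the same forward value that squared_error(w, x, y) would return on its own, so the loss does not need to be recomputed separately.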

Returns

Function that computes a tuple of the gradients with respect to its inputs and the forward computation. By default, the returned function outputs a tuple (gradient, value): the gradient tensor(s) with respect to the first argument and the output of func. If has_aux is True, it outputs (gradient, (value, aux)), where the forward output is paired with the auxiliary objects. If argnums is a tuple of integers, the first element is itself a tuple of gradients, one for each entry in argnums: ((gradient_0, gradient_1, ...), value).
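
The return structure with has_aux=True can be illustrated with a short sketch. The function loss_with_aux and the dictionary it returns are hypothetical and serve only to show where the auxiliary objects end up in the result.

```python
import torch
from torch.func import grad_and_value

def loss_with_aux(w, x):
    # Hypothetical function returning (scalar loss, auxiliary objects).
    pred = x * w
    loss = (pred ** 2).sum()
    return loss, {"pred": pred}

w = torch.tensor(3.0)
x = torch.tensor([1.0, 2.0])

# With has_aux=True the result is (gradient, (value, aux)).
# Only the first element of func's return value is differentiated.
grad_w, (loss, aux) = grad_and_value(loss_with_aux, has_aux=True)(w, x)
```

The auxiliary objects are passed through without being differentiated, which makes them a convenient place for predictions or metrics that are needed alongside the loss.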

Return type

Callable

See grad() for examples
