functorch.grad_and_value
functorch.grad_and_value(func, argnums=0, has_aux=False)

Returns a function that computes a tuple of the gradient and the primal (forward) computation.
- Parameters
func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux is True, the function can return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).
argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.
has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
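The page defers examples to grad(), so here is a minimal sketch of the basic call and of has_aux=True. It uses torch.func.grad_and_value, the replacement recommended in the deprecation warning, and assumes a PyTorch 2.x install. The function names f and f_aux and the "mean" aux key are illustrative only.

```python
import torch
from torch.func import grad_and_value

def f(x):
    # func must return a single-element tensor
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0])

# Default argnums=0: gradient with respect to x, plus the forward value
g, value = grad_and_value(f)(x)  # g == 2 * x, value == 14.0

def f_aux(x):
    # With has_aux=True, func returns (output, aux); aux here is a dict of tensors
    y = (x ** 2).sum()
    return y, {"mean": x.mean()}

# With has_aux=True the result nests as (grad, (output, aux))
g_aux, (value_aux, aux) = grad_and_value(f_aux, has_aux=True)(x)
```

The auxiliary objects are returned alongside the forward value but are not differentiated.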
- Returns
Function to compute a tuple of gradients with respect to its inputs and the forward computation. By default, the output of the function is a tuple of the gradient tensor(s) with respect to the first argument and the primal computation. If has_aux is True, a tuple of the gradients and a tuple of the forward computation with the auxiliary objects, (output, aux), is returned. If argnums is a tuple of integers, the result is a tuple containing a tuple of gradients (one per argnums value) and the forward computation.
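As a sketch of the return structure when argnums is a tuple, assuming PyTorch 2.x and the recommended torch.func.grad_and_value (the function h is hypothetical), the gradients come back as an inner tuple ordered like argnums:

```python
import torch
from torch.func import grad_and_value

def h(x, y):
    # Single-element output: sum of the elementwise product
    return (x * y).sum()

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])

# argnums=(0, 1): result is ((grad_x, grad_y), forward_value)
(gx, gy), value = grad_and_value(h, argnums=(0, 1))(x, y)
# d/dx sum(x * y) = y and d/dy sum(x * y) = x; value = 1*3 + 2*4 = 11

# A single integer argnums returns a bare gradient tensor instead of a tuple
gy_only, value_only = grad_and_value(h, argnums=1)(x, y)
```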
See grad() for examples.

Warning
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.grad_and_value is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.grad_and_value instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/master/func.migrating.html
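Migration is mostly an import change. The snippet below is a minimal sketch that assumes torch.func.grad_and_value accepts the same (func, argnums=0, has_aux=False) arguments documented above. It prefers torch.func and only falls back to the deprecated functorch import on older installs.

```python
# Prefer torch.func (PyTorch >= 2.0); fall back to the deprecated
# functorch location only when torch.func is unavailable.
try:
    from torch.func import grad_and_value
except ImportError:  # pre-2.0 PyTorch with standalone functorch installed
    from functorch import grad_and_value

import torch

# Called the same way as functorch.grad_and_value
grad_fn = grad_and_value(lambda x: (x ** 3).sum())
g, value = grad_fn(torch.tensor(2.0))  # d/dx x^3 at 2 is 12; forward value is 8
```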