torch.func.grad_and_value
- torch.func.grad_and_value(func, argnums=0, has_aux=False)
Returns a function to compute a tuple of the gradient and primal, or forward, computation.
- Parameters
func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux is True, the function can return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).
argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.
has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
- Returns
Function to compute a tuple of gradients with respect to its inputs and the forward computation. By default, the output of the function is a tuple of the gradient tensor(s) with respect to the first argument and the primal computation. If has_aux is True, the function returns a tuple of the gradients and a tuple of the forward computation with the auxiliary objects. If argnums is a tuple of integers, the function returns a tuple of a tuple of the gradients with respect to each argnums value and the forward computation.
- Return type
Callable
See grad() for examples.
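The return shapes described above can be illustrated with a minimal sketch. The names loss_fn, loss_with_aux, w, and x are hypothetical and chosen only for this illustration; the loss is a simple sum of elementwise products so that the gradients are easy to check by hand.

```python
import torch
from torch.func import grad_and_value


def loss_fn(w, x):
    # Must return a single-element Tensor.
    return (w * x).sum()


def loss_with_aux(w, x):
    # Returns (output, aux): the scalar loss plus an auxiliary tensor.
    prod = w * x
    return prod.sum(), prod


w = torch.tensor([1.0, 2.0, 3.0])
x = torch.tensor([4.0, 5.0, 6.0])

# Default (argnums=0): gradient with respect to the first argument, plus the value.
grad_w, value = grad_and_value(loss_fn)(w, x)

# argnums as a tuple: a tuple of gradients, one per listed argument, plus the value.
(grad_w2, grad_x2), value2 = grad_and_value(loss_fn, argnums=(0, 1))(w, x)

# has_aux=True: the gradient, plus a (value, aux) tuple.
grad_w3, (value3, aux) = grad_and_value(loss_with_aux, has_aux=True)(w, x)
```

For this loss, the gradient with respect to w is x and the gradient with respect to x is w, which makes each returned gradient straightforward to verify.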