torch.func.hessian

torch.func.hessian(func, argnums=0)

Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. Hessians can also be computed through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
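The compositions mentioned above can be compared directly. The following is a minimal sketch, assuming a simple cubic test function (the name f and the expected value are illustrative, not part of the API); all three compositions should agree with hessian() up to floating-point tolerance.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

# Illustrative scalar-valued function: f(x) = sum(x_i^3).
# Its analytic Hessian is diag(6 * x).
def f(x):
    return (x ** 3).sum()

torch.manual_seed(0)
x = torch.randn(4)

h_default = hessian(f)(x)          # forward-over-reverse
h_fwd_rev = jacfwd(jacrev(f))(x)   # the same strategy, written out
h_fwd_fwd = jacfwd(jacfwd(f))(x)   # forward-over-forward
h_rev_rev = jacrev(jacrev(f))(x)   # reverse-over-reverse

expected = torch.diag(6 * x)
```

Which composition is fastest depends on the function and its input and output sizes, so it may be worth timing the alternatives on a real workload.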

Parameters:
  • func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors

  • argnums (int or Tuple[int]) – Optional. An integer or tuple of integers specifying which arguments to compute the Hessian with respect to. Default: 0.

Returns:

A function that takes the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
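To illustrate argnums, here is a small sketch with a hypothetical two-argument function g. With an integer argnums, the result is a single Tensor. With a tuple, the sketch assumes the result is a nested tuple in which entry [i][j] holds the second derivative with respect to argument i and then argument j, which follows from composing jacfwd and jacrev with tuple argnums.

```python
import torch
from torch.func import hessian

# Illustrative function of two arguments: g(x, y) = sum(x_i * y_i^3).
def g(x, y):
    return (x * y ** 3).sum()

torch.manual_seed(0)
x = torch.randn(3)
y = torch.randn(3)

# Hessian with respect to y only (argument index 1): diag(6 * x * y).
h_yy = hessian(g, argnums=1)(x, y)

# Hessian blocks with respect to both arguments.
# Assumed layout: h_all[i][j] is d^2 g / (d arg_i d arg_j).
h_all = hessian(g, argnums=(0, 1))(x, y)
h_xx, h_xy = h_all[0]
h_yx, h_yy_block = h_all[1]
```

For this g, the x-x block is zero, the mixed blocks are diag(3 * y^2), and the y-y block matches the integer-argnums result.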

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.
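One way to apply the alternative from the note is a small wrapper that tries hessian() first and falls back to jacrev(jacrev(func)). This is only a sketch: hessian_with_fallback is a hypothetical helper, not part of torch.func, and it assumes the missing forward-mode rule surfaces as a RuntimeError (or a subclass such as NotImplementedError). Catching that broadly may also hide unrelated errors, so a narrower check may be preferable in real code.

```python
import torch
from torch.func import hessian, jacrev

def hessian_with_fallback(func, argnums=0):
    """Hypothetical helper: forward-over-reverse first, reverse-over-reverse on failure."""
    def wrapped(*args):
        try:
            return hessian(func, argnums=argnums)(*args)
        except RuntimeError:
            # Assumption: an unsupported forward-mode operator raises here.
            # Reverse-over-reverse uses only reverse-mode AD.
            return jacrev(jacrev(func, argnums=argnums), argnums=argnums)(*args)
    return wrapped

def f(x):
    return x.sin().sum()

torch.manual_seed(0)
x = torch.randn(5)
hess = hessian_with_fallback(f)(x)
```

For a function whose operators all support forward-mode AD, such as f here, the wrapper simply returns the result of hessian().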

A basic usage with an R^N -> R^1 function gives an N x N Hessian:

>>> import torch
>>> from torch.func import hessian
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))