functorch.hessian

functorch.hessian(func, argnums=0)

Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. It is possible to compute Hessians through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).

Parameters
func: the function whose Hessian is computed.
argnums: index (or indices) of the argument(s) of func to differentiate with respect to. Default: 0.

Returns
Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
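To illustrate how argnums selects the argument being differentiated, here is a minimal sketch. The two-argument function f and the inputs x and y are hypothetical, chosen only because their Hessians are easy to check by hand; they are not part of the original documentation.

```python
import torch
from torch.func import hessian

# Hypothetical scalar-valued function of two tensor arguments.
def f(x, y):
    return (x ** 2 * y ** 3).sum()

x = torch.randn(3)
y = torch.randn(3)

# argnums=0 (the default): Hessian with respect to x, holding y fixed.
hess_x = hessian(f)(x, y)

# argnums=1: Hessian with respect to y, holding x fixed.
hess_y = hessian(f, argnums=1)(x, y)

# For this f the analytic Hessians are diagonal:
#   d^2 f / dx^2 = diag(2 * y^3)
#   d^2 f / dy^2 = diag(6 * x^2 * y)
print(hess_x.shape, hess_y.shape)
```

In both calls the returned function takes the same inputs as f, here (x, y), and only the differentiated argument changes.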
Note
You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.

A basic usage with an R^N -> R^1 function gives an N x N Hessian:

>>> import torch
>>> from torch.func import hessian
>>> def f(x):
...     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
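The alternative compositions mentioned above, including the jacrev(jacrev(func)) fallback from the Note, can be compared directly. The following is a sketch that reuses the same f; the variable names are illustrative only. For this function all compositions should agree up to floating-point error.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)

# Default: forward-over-reverse.
h_default = hessian(f)(x)

# The same composition written out explicitly.
h_fwd_rev = jacfwd(jacrev(f))(x)

# Reverse-over-reverse: the fallback suggested when an operator
# lacks a forward-mode AD rule.
h_rev_rev = jacrev(jacrev(f))(x)

# Forward-over-forward: requires forward-mode AD support for every
# operator involved.
h_fwd_fwd = jacfwd(jacfwd(f))(x)

# Analytic Hessian of sum(sin(x)) for reference.
expected = torch.diag(-x.sin())
```

Which composition is fastest depends on the function and input sizes, so timing them on the actual workload is the safest way to choose.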
Warning
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.hessian is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.hessian instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/master/func.migrating.html
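For code that must run on both older and newer installations, one possible way to follow the deprecation notice is an import fallback. This try/except pattern is an assumption made for illustration and is not taken from the migration guide; the function f is hypothetical.

```python
import torch

try:
    # Preferred location as of PyTorch 2.0.
    from torch.func import hessian
except ImportError:
    # Deprecated location, kept only as a fallback for older installs.
    from functorch import hessian

# Hypothetical test function: the Hessian of sum(x^3) is diag(6 * x).
def f(x):
    return (x ** 3).sum()

x = torch.randn(4)
hess = hessian(f)(x)
```

Once older versions no longer need support, the fallback branch can be dropped and only the torch.func import kept.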