functorch.jacrev

functorch.jacrev(func, argnums=0, *, has_aux=False)

Computes the Jacobian of func with respect to the arg(s) at index argnums using reverse-mode autodiff.

Parameters
- func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
- argnums (int or Tuple[int]) – Optional, integer or tuple of integers, specifying which argument(s) to compute the Jacobian with respect to. Default: 0.
- has_aux (bool) – Flag indicating that func returns an (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.
Returns

Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is the auxiliary objects returned by func.
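To make the shape convention concrete, here is a small sketch with a non-square Jacobian: the result has the output's shape followed by the input's shape. It assumes that functorch.jacrev and torch.func.jacrev behave identically (PyTorch 2.x ships the functorch transforms under torch.func), and the matrix A and the function linear are illustrative names only.

```python
import torch

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev

A = torch.randn(2, 3)  # illustrative 2 x 3 matrix


def linear(x):
    # Maps a 3-vector to a 2-vector, so the Jacobian is 2 x 3.
    return A @ x


x = torch.randn(3)
J = jacrev(linear)(x)

# Shape is output.shape + input.shape; for a linear map the Jacobian is A itself.
print(J.shape)  # torch.Size([2, 3])
```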
Applying jacrev to a pointwise, unary operation gives a diagonal matrix as the Jacobian:

>>> import torch
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
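As a sanity check, the same Jacobian can be obtained from the classic torch.autograd.functional.jacobian API. The difference is that jacrev returns a new function (so it composes with other transforms), while torch.autograd.functional.jacobian evaluates the Jacobian directly. This sketch assumes torch.func.jacrev matches functorch.jacrev.

```python
import torch
from torch.autograd.functional import jacobian as autograd_jacobian

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev

x = torch.randn(5)

# Function-transform style: build the Jacobian function, then call it.
J_transform = jacrev(torch.sin)(x)
# Classic autograd style: evaluate the Jacobian of torch.sin at x directly.
J_autograd = autograd_jacobian(torch.sin, x)

print(torch.allclose(J_transform, J_autograd))

# sin acts elementwise, so every off-diagonal entry is zero.
off_diag = J_transform - torch.diag(torch.diagonal(J_transform))
```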
If you would like to compute the output of the function as well as its Jacobian, use the has_aux flag to return the output as an auxiliary object:

>>> import torch
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> def f(x):
...     return x.sin()
>>> def g(x):
...     result = f(x)
...     return result, result
>>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
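The pattern above can be wrapped in a small reusable helper. The name value_and_jacrev is hypothetical (it is not part of functorch); it is a minimal sketch that returns both f(x) and its Jacobian from a single jacrev call, assuming torch.func.jacrev matches functorch.jacrev.

```python
import torch

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev


def value_and_jacrev(f):
    """Hypothetical helper: returns a function computing (f(x), Jacobian of f at x)."""

    def g(x):
        out = f(x)
        # First element is differentiated; the second is passed through as aux.
        return out, out

    def wrapped(x):
        jac, value = jacrev(g, has_aux=True)(x)
        return value, jac

    return wrapped


x = torch.randn(5)
value, jac = value_and_jacrev(torch.sin)(x)
```

Returning the output through aux avoids a second forward pass just to recover f(x).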
jacrev() can be composed with vmap to produce batched Jacobians:

>>> import torch
>>> from functorch import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
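To see what vmap contributes, the sketch below compares the batched result with an explicit Python loop that calls jacrev once per row. It also shows that calling jacrev directly on the whole (64, 5) tensor treats it as a single input and yields a (64, 5, 64, 5) Jacobian instead. It assumes torch.func matches functorch.

```python
import torch

# Assumption: torch.func.jacrev / vmap (PyTorch >= 2.0) match functorch's.
try:
    from torch.func import jacrev, vmap
except ImportError:
    from functorch import jacrev, vmap

x = torch.randn(64, 5)

# One Jacobian per row, computed in a single vectorized call.
batched = vmap(jacrev(torch.sin))(x)

# Reference: one jacrev call per row, stacked along a new batch dimension.
looped = torch.stack([jacrev(torch.sin)(row) for row in x])

# Without vmap, the whole (64, 5) tensor is treated as one input.
full = jacrev(torch.sin)(x)

print(batched.shape)  # torch.Size([64, 5, 5])
```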
Additionally, jacrev() can be composed with itself to produce Hessians:

>>> import torch
>>> from functorch import jacrev
>>> def f(x):
...     return x.sin().sum()
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
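Reverse-over-reverse is not the only composition: jacrev can also be nested inside jacfwd (forward-over-reverse), which is the composition functorch's hessian is built on. The sketch below checks that both orders agree on a function with a known Hessian, f(x) = sum(x**3), whose Hessian is diag(6x). It assumes torch.func matches functorch.

```python
import torch

# Assumption: torch.func.jacfwd / jacrev (PyTorch >= 2.0) match functorch's.
try:
    from torch.func import jacfwd, jacrev
except ImportError:
    from functorch import jacfwd, jacrev


def f(x):
    return (x ** 3).sum()


x = torch.randn(5)

# Reverse-over-reverse, as in the example above.
hess_rr = jacrev(jacrev(f))(x)
# Forward-over-reverse: same Hessian, different composition order.
hess_fr = jacfwd(jacrev(f))(x)

# Second derivative of x_i ** 3 is 6 * x_i; mixed partials are zero.
expected = torch.diag(6 * x)
```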
By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

>>> import torch
>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
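When the arguments have different shapes, the shape rule still applies per argument: the Jacobian has the output's shape followed by the selected argument's shape. In this sketch the function f(W, x) = W @ x is illustrative, and torch.func.jacrev is assumed to match functorch.jacrev.

```python
import torch

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev


def f(W, x):
    return W @ x  # output shape (2,)


W = torch.randn(2, 3)
x = torch.randn(3)

# With respect to W (argnums=0): shape is output.shape + W.shape.
J_W = jacrev(f, argnums=0)(W, x)
print(J_W.shape)  # torch.Size([2, 2, 3])

# d(W @ x)_i / dW_jk equals x_k when i == j and 0 otherwise.
expected_W = torch.zeros(2, 2, 3)
for i in range(2):
    expected_W[i, i] = x

# With respect to x (argnums=1): the Jacobian of a linear map is W itself.
J_x = jacrev(f, argnums=1)(W, x)
```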
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments; the result is a tuple with one Jacobian per argument:

>>> import torch
>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
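The tuple form is easiest to read when the arguments differ in size, because each entry of the returned tuple then has a visibly different shape. The function f(x, y) = x * y.sum() below is illustrative, and torch.func.jacrev is assumed to match functorch.jacrev.

```python
import torch

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev


def f(x, y):
    # x has 3 entries, y has 2; the output has 3 entries.
    return x * y.sum()


x, y = torch.randn(3), torch.randn(2)

# One Jacobian per entry of argnums, returned in the same order.
jac_x, jac_y = jacrev(f, argnums=(0, 1))(x, y)

print(jac_x.shape, jac_y.shape)  # torch.Size([3, 3]) torch.Size([3, 2])

expected_x = y.sum() * torch.eye(3)       # d f_i / d x_j = y.sum() if i == j
expected_y = x.unsqueeze(1).expand(3, 2)  # d f_i / d y_k = x_i
```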
Note

Using PyTorch's torch.no_grad together with jacrev: the behavior depends on whether the context manager is inside or outside the function being transformed.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c

In this case, jacrev(f)(x) will respect the inner torch.no_grad.
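Concretely, respecting the inner torch.no_grad means c is treated as a constant, so the Jacobian of x - c is the identity rather than I - diag(2x). This sketch checks that, assuming torch.func.jacrev matches functorch.jacrev.

```python
import torch

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev


def f(x):
    with torch.no_grad():
        c = x ** 2  # computed without tracking, so jacrev treats it as a constant
    return x - c


x = torch.randn(5)
J = jacrev(f)(x)

# Without the inner no_grad, the Jacobian would be I - diag(2 * x).
```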
Case 2: Using jacrev inside a torch.no_grad context manager:

>>> with torch.no_grad():
...     J = jacrev(f)(x)

In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a “function transform”: its result should not depend on a context manager that is active outside of f.
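The sketch below checks Case 2 with a hypothetical function square that has no inner torch.no_grad, so any gradient blocking would have to come from the outer context manager. The expected Jacobian of x ** 2 is diag(2x). It assumes torch.func.jacrev matches functorch.jacrev.

```python
import torch

# Assumption: torch.func.jacrev (PyTorch >= 2.0) matches functorch.jacrev.
try:
    from torch.func import jacrev
except ImportError:
    from functorch import jacrev


def square(x):
    # Hypothetical example function with no inner no_grad.
    return x ** 2


x = torch.randn(5)

with torch.no_grad():
    # The outer no_grad does not disable jacrev's own differentiation.
    J = jacrev(square)(x)
```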