functorch.jacrev

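functorch.jacrev(f, argnums=0)
Computes the Jacobian of f with respect to the arg(s) at index argnums using reverse-mode autodiff.
Parameters
f: the function whose Jacobian is computed.
argnums: an integer, or a tuple of integers, selecting which positional argument(s) of f to differentiate with respect to. Default: 0.
Returns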
Returns a function that takes in the same inputs as f and returns the Jacobian of f with respect to the arg(s) at argnums.
A basic usage with a pointwise, unary operation gives a diagonal matrix as the Jacobian:
>>> import torch
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
jacrev() can be composed with vmap to produce batched Jacobians:
>>> import torch
>>> from functorch import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
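As a rough sanity check on the batched case, the sketch below compares vmap(jacrev(f)) against a plain Python loop over samples. The function f and the names xs, batched, and looped are hypothetical and chosen only for illustration; the fallback import from torch.func is an assumption for environments where the functorch package is not installed.

```python
import torch

try:
    from functorch import jacrev, vmap
except ImportError:
    # Assumption: newer PyTorch releases expose the same transforms in torch.func.
    from torch.func import jacrev, vmap


def f(x):
    # Non-pointwise map from R^5 to R^3, so each Jacobian is a dense 3x5 matrix.
    return torch.stack([x.sum(), (x ** 2).sum(), x[0] * x[1]])


xs = torch.randn(64, 5)

# One call: vmap maps jacrev(f) over the leading (batch) dimension.
batched = vmap(jacrev(f))(xs)

# Reference: compute each sample's Jacobian separately and stack the results.
looped = torch.stack([jacrev(f)(x) for x in xs])

assert batched.shape == (64, 3, 5)
assert torch.allclose(batched, looped)
```

The batch dimension comes first in the result, followed by the output shape and then the input shape of a single sample.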
Additionally, jacrev() can be composed with itself to produce Hessians:
>>> import torch
>>> from functorch import jacrev
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
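A Hessian with off-diagonal entries may make the composition easier to check by hand. The sketch below uses a quadratic form, whose gradient is (A + Aᵀ)x and whose Hessian is A + Aᵀ. The matrix A and the function f are hypothetical, and the torch.func fallback import is an assumption for installs without the functorch package.

```python
import torch

try:
    from functorch import jacrev
except ImportError:
    # Assumption: newer PyTorch releases expose the same transform in torch.func.
    from torch.func import jacrev

A = torch.randn(4, 4)


def f(x):
    # Scalar quadratic form x^T A x.
    return x @ A @ x


x = torch.randn(4)

# First derivative of a scalar function: a vector with the same shape as x.
gradient = jacrev(f)(x)

# Jacobian of the gradient: the (4, 4) Hessian.
hessian = jacrev(jacrev(f))(x)

assert torch.allclose(gradient, (A + A.T) @ x, atol=1e-4)
assert torch.allclose(hessian, A + A.T, atol=1e-4)
```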
By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:
>>> import torch
>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
...
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments:
>>> import torch
>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
...
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
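When the arguments have different shapes, each entry of the returned tuple has the shape of the output followed by the shape of the corresponding input. The sketch below illustrates this with a matrix-vector product. The names f, W, x, jac_W, and jac_x are hypothetical, and the torch.func fallback import is an assumption for installs without the functorch package.

```python
import torch

try:
    from functorch import jacrev
except ImportError:
    # Assumption: newer PyTorch releases expose the same transform in torch.func.
    from torch.func import jacrev


def f(W, x):
    # Linear map: output has shape (3,) for W of shape (3, 5) and x of shape (5,).
    return W @ x


W = torch.randn(3, 5)
x = torch.randn(5)

# One Jacobian per requested argument, returned in the same order as argnums.
jac_W, jac_x = jacrev(f, argnums=(0, 1))(W, x)

# Output shape (3,) followed by the shape of each input.
assert jac_W.shape == (3, 3, 5)
assert jac_x.shape == (3, 5)

# d(Wx)/dx is W itself.
assert torch.allclose(jac_x, W)

# d(Wx)_i / dW_jk is x_k when i == j and 0 otherwise.
expected_W = torch.einsum("ij,k->ijk", torch.eye(3), x)
assert torch.allclose(jac_W, expected_W)
```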
Note
Using PyTorch torch.no_grad together with jacrev.
Case 1: Using torch.no_grad inside a function:
>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c
In this case, jacrev(f)(x) will respect the inner torch.no_grad.
Case 2: Using jacrev inside a torch.no_grad context manager:
>>> with torch.no_grad():
...     jacrev(f)(x)
In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a “function transform”: its result should not depend on the result of a context manager outside of f.
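The two cases can be checked numerically with a small sketch like the one below. The function f here combines x and c by subtraction, and g is a no_grad-free counterpart; both are assumptions made for illustration. The torch.func fallback import is likewise an assumption for installs without the functorch package.

```python
import torch

try:
    from functorch import jacrev
except ImportError:
    # Assumption: newer PyTorch releases expose the same transform in torch.func.
    from torch.func import jacrev


def f(x):
    # Inner no_grad: c is treated as a constant when differentiating.
    with torch.no_grad():
        c = x ** 2
    return x - c


def g(x):
    # Same arithmetic without the inner no_grad, for comparison.
    return x - x ** 2


x = torch.randn(5)

# Case 1: the inner no_grad is respected, so only the "x" term contributes.
inner = jacrev(f)(x)

# Case 2: an outer no_grad does not change what jacrev computes.
with torch.no_grad():
    outer_f = jacrev(f)(x)
    outer_g = jacrev(g)(x)

assert torch.allclose(inner, torch.eye(5))
assert torch.allclose(outer_f, torch.eye(5))
assert torch.allclose(outer_g, torch.diag(1 - 2 * x))
```

If the outer context manager were honored, the Jacobian of g could not be computed at all; instead it matches the analytic derivative 1 - 2x on the diagonal.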