functorch.jacrev

functorch.jacrev(f, argnums=0)

Computes the Jacobian of f with respect to the arg(s) at index argnums using reverse-mode autodiff.

Parameters
  • f (function) – A Python function that takes one or more arguments, at least one of which must be a Tensor, and returns one or more Tensors.

  • argnums (int or Tuple[int]) – Optional. An integer or tuple of integers specifying which arguments to compute the Jacobian with respect to. Default: 0.

Returns

A function that takes the same inputs as f and returns the Jacobian of f with respect to the arg(s) at argnums.

Basic usage with a pointwise, unary operation gives a diagonal matrix as the Jacobian:

>>> import torch
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
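
For a function that is not pointwise, the Jacobian is generally neither diagonal nor square. The sketch below uses a hypothetical function g (not part of functorch) that maps a vector of length 3 to a vector of length 2, and illustrates that the result has the output shape followed by the input shape. The fallback import from torch.func is an assumption for newer PyTorch releases.

```python
import torch

try:
    from functorch import jacrev
except ImportError:  # assumption: newer PyTorch releases expose the same transform here
    from torch.func import jacrev


def g(x):
    # Hypothetical example function mapping R^3 -> R^2.
    return torch.stack([x[0] * x[1], x[2].sin()])


x = torch.randn(3)
jacobian = jacrev(g)(x)  # shape: output shape (2,) followed by input shape (3,)

# Hand-derived Jacobian: rows index outputs, columns index inputs.
zero = torch.zeros(())
expected = torch.stack([
    torch.stack([x[1], x[0], zero]),
    torch.stack([zero, zero, x[2].cos()]),
])
```

Here jacobian[i, j] holds the derivative of output i with respect to input j.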

jacrev() can be composed with vmap to produce batched Jacobians:

>>> from functorch import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
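
One way to sanity-check the batched result is to compare it with a per-sample Python loop. This is a minimal sketch, not a recommended pattern for production code; the fallback import from torch.func is an assumption for newer PyTorch releases.

```python
import torch

try:
    from functorch import jacrev, vmap
except ImportError:  # assumption: newer PyTorch releases expose the same transforms here
    from torch.func import jacrev, vmap

x = torch.randn(64, 5)

# Batched Jacobians computed in one call.
batched = vmap(jacrev(torch.sin))(x)

# Reference: apply jacrev to each sample separately and stack the results.
looped = torch.stack([jacrev(torch.sin)(row) for row in x])

# Analytic reference: one diagonal matrix of cos(x) per sample.
analytic = torch.diag_embed(torch.cos(x))
```

All three tensors have shape (64, 5, 5); the vmap version avoids the Python-level loop.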

Additionally, jacrev() can be composed with itself to produce Hessians:

>>> from functorch import jacrev
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
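
The composition works because the inner jacrev(f) returns the gradient of the scalar output (shape (5,)), and the outer jacrev differentiates that gradient again (shape (5, 5)). As a hedged cross-check, the sketch below compares the result with torch.autograd.functional.hessian; the fallback import from torch.func is an assumption for newer PyTorch releases.

```python
import torch

try:
    from functorch import jacrev
except ImportError:  # assumption: newer PyTorch releases expose the same transform here
    from torch.func import jacrev


def f(x):
    return x.sin().sum()


x = torch.randn(5)

gradient = jacrev(f)(x)         # first derivative, shape (5,)
hessian = jacrev(jacrev(f))(x)  # second derivative, shape (5, 5)

# Reference computed with the standard autograd helper.
reference = torch.autograd.functional.hessian(f, x)
```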

By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
...
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)

Additionally, passing a tuple to argnums computes the Jacobian with respect to multiple arguments and returns a tuple of Jacobians, one per argument:

>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
...
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0,1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
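
The tuple result can be unpacked directly. As a hedged cross-check, the sketch below compares it with torch.autograd.functional.jacobian, which also returns one Jacobian per input when given a tuple of inputs; the fallback import from torch.func is an assumption for newer PyTorch releases.

```python
import torch

try:
    from functorch import jacrev
except ImportError:  # assumption: newer PyTorch releases expose the same transform here
    from torch.func import jacrev


def f(x, y):
    return x + y ** 2


x, y = torch.randn(5), torch.randn(5)

# One Jacobian per entry of argnums, in the same order.
result = jacrev(f, argnums=(0, 1))(x, y)
jac_x, jac_y = result

# Reference computed with the standard autograd helper.
ref_x, ref_y = torch.autograd.functional.jacobian(f, (x, y))
```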

Note

Using PyTorch's torch.no_grad together with jacrev:

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c

In this case, jacrev(f)(x) respects the inner torch.no_grad: c is treated as a constant when differentiating.

Case 2: Using jacrev inside torch.no_grad context manager:

>>> with torch.no_grad():
...     jacrev(f)(x)

In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a “function transform”: its result should not depend on a context manager outside of f.
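
The two cases can be checked numerically. In the sketch below, which reuses the f from Case 1, c is treated as a constant, so the Jacobian of x - c is the identity matrix rather than the identity minus diag(2 * x); wrapping the call in an outer torch.no_grad leaves the result unchanged. The fallback import from torch.func is an assumption for newer PyTorch releases.

```python
import torch

try:
    from functorch import jacrev
except ImportError:  # assumption: newer PyTorch releases expose the same transform here
    from torch.func import jacrev


def f(x):
    with torch.no_grad():
        c = x ** 2  # computed without gradient tracking, so it acts as a constant
    return x - c


x = torch.randn(5)

# Case 1: the inner no_grad is respected.
inner = jacrev(f)(x)

# Case 2: an outer no_grad does not change the result of the transform.
with torch.no_grad():
    outer = jacrev(f)(x)

identity = torch.eye(5)
```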