functorch.vjp

functorch.vjp(f, *primals)

Standing for the vector-Jacobian product, returns a tuple containing the results of f applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of f with respect to primals times cotangents.
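
In other words, if f maps R^n to R^m and has Jacobian J (an m x n matrix), then for a cotangent v in R^m the returned function computes the vector-Jacobian product v^T J (equivalently, J^T v), a vector in R^n.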

Parameters
  • f (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.

  • primals (Tensors) – Positional arguments to f that must all be Tensors. The returned function will also compute the derivatives with respect to these arguments.

Returns

Returns a tuple containing the output of f applied to primals and a function that computes the vjp of f with respect to all primals using the cotangents passed to the returned function. The returned function returns a tuple of VJPs, one for each primal.

When used in simple cases, vjp() behaves the same as grad():

>>> import torch
>>> import functorch
>>> # (these imports are assumed by the rest of the examples on this page)
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, functorch.grad(f)(x))
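
The same quantity can be computed with plain PyTorch autograd. A minimal sketch (not part of the original docs) showing the correspondence, where torch.autograd.grad's grad_outputs plays the role of the cotangent:

>>> x = torch.randn([5], requires_grad=True)
>>> y = x.sin().sum()
>>> # grad_outputs is the cotangent passed to the reverse pass
>>> (g,) = torch.autograd.grad(y, x, grad_outputs=torch.tensor(1.))
>>> (_, vjpfunc) = functorch.vjp(lambda x: x.sin().sum(), x.detach())
>>> assert torch.allclose(g, vjpfunc(torch.tensor(1.))[0])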

However, vjp() can support functions with multiple outputs by passing in one cotangent for each of the outputs; the contributions from all outputs are summed into a single VJP per primal:

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() - x.sin())
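
Each output's cotangent weights that output's rows of the Jacobian. As an added illustration (reusing x and vjpfunc from the example above), zeroing the second cotangent isolates the sin branch:

>>> vjps = vjpfunc((torch.full([5], 2.), torch.zeros([5])))
>>> assert torch.allclose(vjps[0], 2 * x.cos())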

vjp() can even support outputs that are Python structs:

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() - x.sin())

The function returned by vjp() will compute the partials with respect to each of the primals:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primals are the positional arguments to f. All keyword arguments take their default values:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
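
Because only the positional primals are differentiated, one way (a sketch, not from the original docs) to get a gradient for what would otherwise be a keyword argument is to pass it positionally, e.g. through a lambda:

>>> x, scale = torch.randn([5]), torch.tensor(2.)
>>> (_, vjpfunc) = functorch.vjp(lambda x, scale: x * scale, x, scale)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 2.))
>>> assert torch.allclose(vjps[1], x.sum())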

Note

There is a subtlety when using PyTorch's torch.no_grad together with vjp. Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, vjp(f, x) will respect the inner torch.no_grad.
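
As an added illustration of what this means: since c is computed under torch.no_grad, it is treated as a constant, so the VJP of x - c with respect to x is all ones:

>>> x = torch.randn([5])
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.ones_like(x))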

Case 2: Using vjp inside the torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f, x)

In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not be affected by a context manager outside of f.
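
An illustrative sketch of Case 2, continuing with the same f: the outer torch.no_grad does not suppress the transform, while the inner one in f is still honored:

>>> x = torch.randn([5])
>>> with torch.no_grad():
>>>     (_, vjpfunc) = functorch.vjp(f, x)
>>>     vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.ones_like(x))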