functorch.vjp

functorch.vjp(f, *primals)[source]

Standing for the vector-Jacobian product, returns a tuple containing the results of f applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of f with respect to primals times cotangents.

Parameters
    f (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
    primals (Tensors) – Positional arguments to f that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments.
Returns

Returns a tuple containing the output of f applied to primals and a function that computes the vjp of f with respect to all primals using the cotangents passed to the returned function. The returned function will return a tuple of each VJP.
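Concretely, for cotangents v the returned function computes v times the Jacobian of f at primals. As a sanity check, this can be verified against an explicitly materialized Jacobian (a sketch; torch.autograd.functional.jacobian is used here only for comparison and is not required by vjp()):

>>> x = torch.randn([3])
>>> f = lambda x: x ** 2
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> v = torch.randn([3])
>>> J = torch.autograd.functional.jacobian(f, x)  # [3, 3] Jacobian, here diag(2 * x)
>>> assert torch.allclose(vjpfunc(v)[0], v @ J)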
When used in simple cases, vjp() behaves the same as grad():

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, functorch.grad(f)(x))
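Because the vjp is linear in the cotangents, scaling the cotangents scales the result. Continuing the example above (a sketch, assuming the returned function's default retain_graph behavior allows calling it again):

>>> grad2 = vjpfunc(torch.tensor(2.))[0]
>>> assert torch.allclose(grad2, 2 * functorch.grad(f)(x))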
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp() can even support outputs being Python structs; the cotangents must have the same structure as the output of f, so a dict output takes a dict of cotangents:

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
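These asserts follow from the standard matrix-calculus identities for C = matmul(X, Y): the partial with respect to X is cotangents @ Yᵀ and the partial with respect to Y is Xᵀ @ cotangents.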
primals are the positional arguments for f. All kwargs use their default value:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
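To use a non-default keyword value, one option is to bind it before handing the function to vjp() (a sketch; functools.partial is plain Python, not part of the vjp() API):

>>> from functools import partial
>>> (_, vjpfunc) = functorch.vjp(partial(f, scale=2.), x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 2.))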
Note

Using PyTorch torch.no_grad together with vjp.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, vjp(f, x) will respect the inner torch.no_grad.
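For example (a sketch): since c is computed under no_grad, it is treated as a constant, so the vjp of f reduces to that of x alone:

>>> x = torch.randn([])
>>> (_, vjpfunc) = vjp(f, x)
>>> grad = vjpfunc(torch.ones_like(x))[0]
>>> assert torch.allclose(grad, torch.ones_like(x))  # d(x - c)/dx == 1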
Case 2: Using vjp inside the torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f, x)
In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on the result of a context manager outside of f.
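For example (a sketch): even under an outer torch.no_grad, the function returned by vjp still computes gradients of f:

>>> x = torch.randn([])
>>> with torch.no_grad():
>>>     (_, vjpfunc) = vjp(torch.sin, x)
>>>     grad = vjpfunc(torch.ones_like(x))[0]
>>> assert torch.allclose(grad, x.cos())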