functorch.jvp
functorch.jvp(func, primals, tangents, *, strict=False, has_aux=False)
Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the “Jacobian of func evaluated at primals” times tangents. This is also known as forward-mode autodiff.
- Parameters
func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
primals (Tensors) – Positional arguments to func that must all be Tensors. The derivative is computed with respect to these arguments.
tangents (Tensors) – The “vector” for which the Jacobian-vector product is computed. Must have the same structure and sizes as the inputs to func.
has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns
Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.
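The has_aux behaviour described above can be illustrated with a minimal sketch. The function f and its auxiliary value (the input norm) are hypothetical choices made for this example only; the call assumes torch.func.jvp, the replacement recommended for functorch.jvp.

```python
import torch
from torch.func import jvp

def f(x):
    # Hypothetical example function: returns (output, aux).
    # Only the first element is differentiated; aux is passed through.
    output = x.sin()
    aux = x.norm()
    return output, aux

x = torch.randn(3, dtype=torch.float64)
v = torch.ones(3, dtype=torch.float64)

# With has_aux=True the result is a 3-tuple instead of a 2-tuple.
output, jvp_out, aux = jvp(f, (x,), (v,), has_aux=True)

assert torch.allclose(output, x.sin())
# d/dx sin(x) = cos(x), so the JVP along v is cos(x) * v.
assert torch.allclose(jvp_out, x.cos() * v)
assert torch.allclose(aux, x.norm())
```

Returning auxiliary values this way avoids calling func a second time just to retrieve quantities that are not part of the differentiated output.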
Note
You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it.
jvp is useful when you wish to compute gradients of a function R^1 -> R^N:
>>> import torch
>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
jvp() can support functions with multiple inputs by passing in the tangents for each of the inputs:
>>> import torch
>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
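To make the phrase “Jacobian of func evaluated at primals, times tangents” concrete, the sketch below compares jvp against an explicitly materialized Jacobian. The function f is a hypothetical example, and torch.func.jacfwd is used here only as a cross-check; it is not part of the jvp API itself.

```python
import torch
from torch.func import jvp, jacfwd

def f(x):
    # Hypothetical R^4 -> R^4 function with a dense Jacobian.
    return x.sin() * x.sum()

# float64 keeps the numerical comparison below tight.
x = torch.randn(4, dtype=torch.float64)
v = torch.randn(4, dtype=torch.float64)

# Forward-mode product: computes J(x) @ v without building J(x).
value, jvp_out = jvp(f, (x,), (v,))

# Reference: build the full 4x4 Jacobian, then multiply by v.
jac = jacfwd(f)(x)

assert torch.allclose(value, f(x))
assert torch.allclose(jvp_out, jac @ v)
```

Building the full Jacobian is shown only for verification; when a single directional derivative is all that is needed, jvp avoids materializing the matrix.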
Warning
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/master/func.migrating.html