jvp(func, primals, tangents, *, strict=False, has_aux=False)
Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the “Jacobian of func evaluated at primals” times tangents. This is also known as forward-mode autodiff.
func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
primals (Tensors) – Positional arguments to func that must all be Tensors. The derivative is computed with respect to these arguments.
tangents (Tensors) – The “vector” for which the Jacobian-vector product is computed. Must have the same structure and sizes as the inputs to func.
has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.
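The has_aux behavior described above can be illustrated with a minimal sketch. The function f and the choice of x.cos() as the auxiliary value are hypothetical, picked only for illustration; any extra object that should be passed through without differentiation could take its place.

```python
import torch
from torch.func import jvp

def f(x):
    # Hypothetical function: the first element is differentiated,
    # the second (aux) is returned as-is and not differentiated.
    return x.sin(), x.cos()

x = torch.randn(3)
v = torch.ones(3)  # tangent with the same shape as x

# With has_aux=True, jvp returns a three-element tuple.
output, jvp_out, aux = jvp(f, (x,), (v,), has_aux=True)

assert torch.allclose(output, x.sin())
# d/dx sin(x) = cos(x), applied elementwise to the tangent v
assert torch.allclose(jvp_out, x.cos() * v)
assert torch.allclose(aux, x.cos())
```

Without has_aux=True, the same call would treat both elements of the returned tuple as outputs to differentiate.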
You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it.
jvp is useful when you wish to compute gradients of a function R^1 -> R^N:
>>> import torch
>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
jvp() can support functions with multiple inputs by passing in the tangents for each of the inputs:
>>> import torch
>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
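To make the “Jacobian times tangents” description concrete, the following sketch compares the result of jvp against an explicitly materialized Jacobian. The function f (mapping R^3 -> R^2) is a hypothetical example, and torch.func.jacfwd is used here only as one convenient way to build the full Jacobian for the check; jvp itself never forms that matrix.

```python
import torch
from torch.func import jvp, jacfwd

def f(x):
    # Hypothetical function mapping R^3 -> R^2
    return torch.stack([x[0] * x[1], x[1] + x[2].sin()])

# float64 keeps the numerical comparison below tight
x = torch.randn(3, dtype=torch.float64)
v = torch.randn(3, dtype=torch.float64)

# Forward-mode: one pass yields f(x) and J(x) @ v together
output, jvp_out = jvp(f, (x,), (v,))

# Reference: build the full 2x3 Jacobian, then multiply by the tangent
J = jacfwd(f)(x)

assert torch.allclose(output, f(x))
assert torch.allclose(jvp_out, J @ v)
```

For functions with many outputs and few inputs, a single jvp call per input direction is typically cheaper than building the whole Jacobian as done in the reference computation here.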
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/master/func.migrating.html