functorch.jvp

functorch.jvp(func, primals, tangents, *, strict=False, has_aux=False)

Computes the Jacobian-vector product (JVP): returns a tuple containing the output of func(*primals) and the Jacobian of func evaluated at primals, multiplied by tangents. This is also known as forward-mode autodiff.
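As a rough illustration of the definition above, the following sketch compares the second element returned by jvp against an explicitly materialized Jacobian multiplied by the tangent. The function f, the input sizes, and the use of torch.func.jacrev to build the full Jacobian are assumptions chosen for the example, not part of this API's documentation.

```python
import torch
from torch.func import jacrev, jvp

def f(x):
    # Illustrative function mapping R^2 -> R^3.
    return torch.stack([x[0] * x[1], x[0].sin(), x[1] ** 2])

x = torch.randn(2)  # primal: the point at which the Jacobian is evaluated
v = torch.randn(2)  # tangent: the vector multiplied by the Jacobian

output, jvp_out = jvp(f, (x,), (v,))

# Build the full 3x2 Jacobian only to check the result.
jacobian = jacrev(f)(x)

assert torch.allclose(output, f(x))
assert torch.allclose(jvp_out, jacobian @ v, atol=1e-6)
```

jvp does not need to materialize the Jacobian; the explicit matrix here serves only as a reference for the comparison.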

Parameters
  • func (function) – A Python function that takes one or more arguments, at least one of which must be a Tensor, and returns one or more Tensors.

  • primals (Tensors) – Positional arguments to func that must all be Tensors. The derivative is computed with respect to these arguments.

  • tangents (Tensors) – The “vector” for which the Jacobian-vector product is computed. Must have the same structure and sizes as primals, the inputs to func.

  • has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.

Returns

Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.
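A minimal sketch of the has_aux=True return shape, assuming a hypothetical function f that returns an intermediate tensor as its auxiliary value:

```python
import torch
from torch.func import jvp

def f(x):
    y = x.sin()
    # First element is differentiated; the second is auxiliary and
    # is passed through without being differentiated.
    return y.sum(), y

x = torch.randn(3)
v = torch.ones(3)

output, jvp_out, aux = jvp(f, (x,), (v,), has_aux=True)

assert torch.allclose(output, x.sin().sum())
# d/dx sum(sin(x)) applied to a tangent of ones is sum(cos(x)).
assert torch.allclose(jvp_out, x.cos().sum(), atol=1e-6)
assert torch.allclose(aux, x.sin())
```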

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it.

jvp is useful when you wish to compute gradients of a function R^1 -> R^N:

>>> import torch
>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3.])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2., 3.]))

jvp() supports functions with multiple inputs: pass one tangent for each input.

>>> import torch
>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
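Because the Jacobian-vector product is linear in the tangents, passing a zero tangent for an input removes that input's contribution. The sketch below reuses the same f(x, y) = x * y to isolate the derivative with respect to one input at a time; the variable names dx_only and dy_only are illustrative.

```python
import torch
from torch.func import jvp

x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: x * y

# Tangent of ones for x, zeros for y: only the x-direction contributes.
_, dx_only = jvp(f, (x, y), (torch.ones(5), torch.zeros(5)))
assert torch.allclose(dx_only, y)

# Tangent of zeros for x, ones for y: only the y-direction contributes.
_, dy_only = jvp(f, (x, y), (torch.zeros(5), torch.ones(5)))
assert torch.allclose(dy_only, x)
```

Summing the two results recovers the x + y obtained when both tangents are ones.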

Warning

We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/main/func.migrating.html