torch.func.linearize(func, *primals)

Returns the value of func at primals and a linear approximation of func at primals.

Parameters

  • func (Callable) – A Python function that takes one or more arguments.

  • primals (Tensors) – Positional arguments to func that must all be Tensors. These are the values at which the function is linearly approximated.


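Returns

A (output, jvp_fn) tuple containing the output of func applied to primals and a function that computes the jvp of func evaluated at primals. jvp_fn takes one tangent per primal, passed positionally in the same order as primals, and each tangent must have the same shape as its primal.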

Return type

Tuple[Any, Callable]
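The sketch below illustrates the return value for a function of two tensors. The function fn and the tensors x, y, tx, ty are illustrative names, not part of the API. The tangent output is compared against the product rule worked out by hand, assuming float64 tangents that match the primals' shape and dtype.

```python
import torch
from torch.func import linearize

def fn(x, y):
    # Illustrative function of two tensor inputs: f(x, y) = x * sin(y)
    return x * y.sin()

x = torch.randn(4, dtype=torch.float64)
y = torch.randn(4, dtype=torch.float64)

# One primal per positional argument of fn
output, jvp_fn = linearize(fn, x, y)

# One tangent per primal, in the same order, with matching shape and dtype
tx = torch.randn(4, dtype=torch.float64)
ty = torch.randn(4, dtype=torch.float64)
tangent_out = jvp_fn(tx, ty)

# Product rule by hand: d(x * sin y) = dx * sin y + x * cos y * dy
expected = tx * y.sin() + x * y.cos() * ty
print(torch.allclose(tangent_out, expected))
```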

linearize is useful if jvp is to be computed multiple times at the same primals. However, to achieve this, linearize saves intermediate computations and has higher memory requirements than directly applying jvp. So, if all the tangents are known in advance, it may be more efficient to compute vmap(jvp) instead of using linearize.
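As a rough sketch of the two options described above, the example below pushes three known tangent directions (the rows of an identity matrix) through sin, once by reusing jvp_fn from linearize and once by batching torch.func.jvp with torch.func.vmap. The helper jvp_single is a hypothetical name introduced for illustration; both routes should produce the same Jacobian, here diag(cos(x)).

```python
import torch
from torch.func import jvp, linearize, vmap

def fn(x):
    return x.sin()

x = torch.randn(3)
tangents = torch.eye(3)  # three tangent directions, one per row

# Option 1: linearize once, then reuse jvp_fn for each tangent
output, jvp_fn = linearize(fn, x)
via_linearize = torch.stack([jvp_fn(t) for t in tangents])

# Option 2: all tangents are known upfront, so batch jvp with vmap
def jvp_single(t):
    # jvp returns (output, jvp_out); keep only the tangent output
    _, tangent_out = jvp(fn, (x,), (t,))
    return tangent_out

via_vmap = vmap(jvp_single)(tangents)

print(torch.allclose(via_linearize, via_vmap))
```

The vmap route does not keep a recorded computation around between calls, which is the memory trade-off the paragraph above refers to.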


linearize evaluates func twice. Please file an issue for an implementation with a single evaluation.
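Because func is evaluated twice, any Python side effects inside it also run twice. The sketch below counts calls with a module-level counter (calls is an illustrative name). It assumes, as an implementation detail rather than a documented guarantee, that jvp_fn replays a recorded computation instead of calling func again, so the count does not grow when jvp_fn is used.

```python
import torch
from torch.func import linearize

calls = 0

def fn(x):
    global calls
    calls += 1  # side effect: count how many times Python runs fn
    return x.sin()

output, jvp_fn = linearize(fn, torch.zeros(3))
print(calls)  # the note above says func is evaluated twice

# Assumption: applying jvp_fn does not re-run fn itself
jvp_fn(torch.ones(3))
print(calls)
```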

>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
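Since jvp_fn is the linear part of func around primals, output + jvp_fn(delta) should match func(primals + delta) up to second-order terms for a small step delta. The check below is a minimal sketch of that first-order Taylor approximation; the inputs x and delta are illustrative, and float64 is used so that rounding error stays well below the approximation error.

```python
import torch
from torch.func import linearize

def fn(x):
    return x.sin()

x = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)
delta = torch.full_like(x, 1e-3)  # small step away from the primal

output, jvp_fn = linearize(fn, x)

# First-order Taylor approximation: fn(x + delta) is approx. fn(x) + J(x) @ delta
approx = output + jvp_fn(delta)
exact = fn(x + delta)

# The remaining error comes from the second-order term, roughly |sin(x)| * delta**2 / 2
error = (approx - exact).abs().max()
print(error.item())
```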