torch.func.linearize

torch.func.linearize(func, *primals)

Returns the value of func at primals together with a linear approximation of func at primals.

Parameters
  • func (Callable) – A Python function that takes one or more arguments.

  • primals (Tensors) – Positional arguments to func that must all be Tensors. These are the values at which the function is linearly approximated.
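As a minimal sketch of how several primals are passed, the snippet below linearizes a two-argument function. The function `mul` and the tensor values are illustrative assumptions, not part of the original documentation. It also assumes that the returned `jvp_fn` takes one tangent per primal, in the same order.

```python
import torch
from torch.func import linearize

# Hypothetical two-argument function used only for illustration.
def mul(x, y):
    return x * y

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

# Each primal is passed as its own positional argument.
output, jvp_fn = linearize(mul, x, y)

# One tangent per primal, matching the primals' shapes and dtypes.
tx = torch.ones(3)
ty = torch.zeros(3)
jvp_out = jvp_fn(tx, ty)

# Product rule for x * y: the jvp is tx * y + x * ty.
expected = tx * y + x * ty
```

With `ty` set to zero, the result reduces to `tx * y`, which is the directional derivative with respect to `x` alone.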

Returns

Returns an (output, jvp_fn) tuple containing the output of func applied to primals and a function that computes the Jacobian-vector product (jvp) of func evaluated at primals.

Return type

Tuple[Any, Callable]

linearize is useful if the jvp is to be computed multiple times at the same primals. However, to achieve this, linearize saves intermediate computation and has higher memory requirements than directly applying jvp. So, if all the tangents are known up front, it may be more efficient to compute vmap(jvp) instead of using linearize.
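The trade-off described above can be sketched as follows. This is an illustration under assumed inputs (the function `fn`, the random primal, and the standard-basis tangents are all chosen for the example), not a benchmark. It only shows that both routes produce the same jvps.

```python
import torch
from torch.func import jvp, linearize, vmap

def fn(x):
    return x.sin()

primal = torch.randn(3)
# A batch of tangents, one per row (here the standard basis, for illustration).
tangents = torch.eye(3)

# Route 1: linearize once, then reuse jvp_fn for each tangent as it arrives.
output, jvp_fn = linearize(fn, primal)
jvps_linearize = torch.stack([jvp_fn(t) for t in tangents])

# Route 2: all tangents are known up front, so batch jvp over them with vmap.
def jvp_only(tangent):
    # torch.func.jvp returns (output, jvp_out); keep only the jvp.
    return jvp(fn, (primal,), (tangent,))[1]

jvps_vmap = vmap(jvp_only)(tangents)
```

Route 1 fits the case where tangents become available one at a time. Route 2 avoids holding the linearized computation in memory when the whole batch is available at once.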

Note

linearize evaluates func twice. Please file an issue if you need an implementation that evaluates func only once.

Example
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>
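As a rough sanity check on the example above, the sketch below compares the results of linearize against a direct call to torch.func.jvp at a non-zero primal. The random inputs are assumptions made for illustration.

```python
import torch
from torch.func import jvp, linearize

def fn(x):
    return x.sin()

primal = torch.randn(3, 3)
tangent = torch.randn(3, 3)

output, jvp_fn = linearize(fn, primal)
linearized_jvp = jvp_fn(tangent)

# Reference values from a single direct jvp call at the same primal.
expected_output, expected_jvp = jvp(fn, (primal,), (tangent,))

# jvp_fn is linear in the tangent, so scaling the tangent scales the result.
scaled_jvp = jvp_fn(2.0 * tangent)
```

For `x.sin()`, the jvp equals `primal.cos() * tangent`, which is why the example above returns all ones at a zero primal with a tangent of ones.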
