functorch

Warning

We’ve integrated functorch into PyTorch. As the final step of the integration, the functorch APIs are deprecated as of PyTorch 2.0. Please use the torch.func APIs instead and see the migration guide and docs for more details.

Function Transforms

vmap

vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs.
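As a minimal sketch, vmap can lift torch.dot, which only accepts 1-D tensors, to batches of vectors. The batch size and shapes are arbitrary, and the functorch import matches this page; per the warning above, torch.func.vmap accepts the same call.

```python
import torch
from functorch import vmap

# Two batches of five 3-element vectors.
x = torch.randn(5, 3)
y = torch.randn(5, 3)

# in_dims defaults to 0, so vmap maps torch.dot over the leading dimension
# of both inputs and stacks the per-example results.
batched_dot = vmap(torch.dot)
out = batched_dot(x, y)

print(out.shape)  # torch.Size([5])
```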

grad

The grad operator computes gradients of func with respect to the input(s) specified by argnums.
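A short sketch of grad on a scalar-valued function, including nesting for a second derivative and argnums for several inputs. The function f and the tensor sizes are illustrative choices.

```python
import torch
from functorch import grad

# grad(func) requires func to return a single-element tensor.
x = torch.randn([])
cos_x = grad(torch.sin)(x)             # d/dx sin(x) = cos(x)
neg_sin_x = grad(grad(torch.sin))(x)   # second derivative: -sin(x)

def f(a, b):
    return (a * b).sum()

a, b = torch.randn(3), torch.randn(3)
# argnums=(0, 1) returns one gradient per selected argument.
ga, gb = grad(f, argnums=(0, 1))(a, b)  # ga == b, gb == a
```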

grad_and_value

Returns a function that computes a tuple of the gradient and the primal (forward) output of func.
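A minimal sketch of grad_and_value, which yields the gradient and the loss value from a single call instead of running the function twice. The loss function here is a hypothetical example chosen so the gradient has a simple closed form.

```python
import torch
from functorch import grad_and_value

def loss(w, x):
    # Hypothetical scalar loss: squared norm of the elementwise product.
    return ((w * x) ** 2).sum()

w = torch.randn(4)
x = torch.randn(4)

# Gradient w.r.t. argnums=0 (w) plus the value of loss(w, x).
g, value = grad_and_value(loss)(w, x)
# Analytically, d/dw sum((w * x)^2) = 2 * w * x^2.
```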

vjp

vjp, which stands for vector-Jacobian product, returns a tuple containing the result of func applied to primals and a function that, when given cotangents, computes the product of those cotangents with the Jacobian of func evaluated at primals, using reverse-mode autodiff.
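A small sketch using elementwise sin, whose Jacobian is diag(cos(x)), so the vector-Jacobian product reduces to v * cos(x). The sizes are arbitrary.

```python
import torch
from functorch import vjp

x = torch.randn(5)
# vjp returns func(x) and a function that maps a cotangent v to v^T J.
out, vjp_fn = vjp(torch.sin, x)

v = torch.randn(5)
# vjp_fn returns a tuple with one entry per primal.
(vjp_x,) = vjp_fn(v)
```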

jvp

jvp, which stands for Jacobian-vector product, returns a tuple containing the output of func(*primals) and the Jacobian of func evaluated at primals multiplied by tangents, using forward-mode autodiff.
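The forward-mode counterpart of the vjp example: for elementwise sin the Jacobian-vector product is cos(x) * t. Note that primals and tangents are passed as tuples.

```python
import torch
from functorch import jvp

x = torch.randn(5)
t = torch.randn(5)

# jvp evaluates sin at x and pushes the tangent t forward through it.
out, jvp_out = jvp(torch.sin, (x,), (t,))
```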

jacrev

Computes the Jacobian of func with respect to the arg(s) at index argnums using reverse-mode autodiff.
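A sketch of jacrev on an elementwise function, where the Jacobian is diagonal, and of argnums selecting a different input. The two-argument f is a hypothetical example.

```python
import torch
from functorch import jacrev

x = torch.randn(5)
# Jacobian of elementwise sin is diag(cos(x)), shape (5, 5).
jac = jacrev(torch.sin)(x)

def f(x, y):
    return x.sin() + y.cos()

y = torch.randn(5)
# argnums=1 differentiates with respect to y: diag(-sin(y)).
jac_y = jacrev(f, argnums=1)(x, y)
```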

jacfwd

Computes the Jacobian of func with respect to the arg(s) at index argnums using forward-mode autodiff.
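jacfwd and jacrev compute the same Jacobian; forward mode tends to be cheaper when the function has more outputs than inputs, and reverse mode when it has fewer. This sketch uses a hypothetical map from 2 inputs to 10 outputs and checks that both agree.

```python
import torch
from functorch import jacfwd, jacrev

W = torch.randn(10, 2)

def f(x):
    # Hypothetical function with 2 inputs and 10 outputs ("tall" Jacobian).
    return torch.tanh(W @ x)

x = torch.randn(2)
J_fwd = jacfwd(f)(x)  # shape (10, 2)
J_rev = jacrev(f)(x)  # same values, computed with reverse mode
```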

hessian

Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.
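Forward-over-reverse means taking jacfwd of jacrev. This sketch checks hessian against that composition and against the closed form for sum(x^3), whose Hessian is diag(6x). The test function is an illustrative choice.

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return (x ** 3).sum()

x = torch.randn(3)
H = hessian(f)(x)                   # diag(6 * x)
H_manual = jacfwd(jacrev(f))(x)     # explicit forward-over-reverse
```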

functionalize

functionalize is a transform that can be used to remove (intermediate) mutations and aliasing from a function, while preserving the function’s semantics.
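A minimal sketch: the hypothetical function below mutates an intermediate tensor through a view. Running it under functionalize gives the same result, since only the internal mutations are rewritten; inspecting the rewritten program typically means pairing functionalize with a graph tracer, which is not shown here.

```python
import torch
from functorch import functionalize

def f(a):
    b = a + 1
    c = b.view(-1)
    c.add_(1)  # in-place mutation of an intermediate, through a view of b
    return b   # b now holds a + 2

x = torch.randn(2)
out_ref = f(x)
out_func = functionalize(f)(x)  # same semantics, mutation removed internally
```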

Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a torch.nn.Module. For example, the following computes the Jacobian of a function that takes three values and returns three values:

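import torch
from functorch import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
# Jacobian of the output with respect to the input x (equal to model.weight)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)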

However, if you want to compute something like a Jacobian with respect to the parameters of the model, you need a way to construct a function whose inputs are the parameters. That’s what make_functional() and make_functional_with_buffers() are for: given a torch.nn.Module, they return a new function that accepts the parameters and the inputs to the Module’s forward pass.

make_functional

Given a torch.nn.Module, make_functional() extracts the state (params) and returns a functional version of the model, func.
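A sketch of the parameter-Jacobian use case described above, reusing the Linear(3, 3) model from the earlier example. func takes the parameters first, then the forward inputs, so jacrev with the default argnums=0 differentiates with respect to the parameter tuple.

```python
import torch
from functorch import make_functional, jacrev

model = torch.nn.Linear(3, 3)
func, params = make_functional(model)  # params: (weight, bias)

x = torch.randn(3)
out = func(params, x)  # same result as model(x)

# One Jacobian per parameter: weight -> (3, 3, 3), bias -> (3, 3).
jac_params = jacrev(func)(params, x)
```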

make_functional_with_buffers

make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers
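A sketch with a module that owns buffers. BatchNorm1d keeps running_mean, running_var, and num_batches_tracked as buffers, which are returned separately from the parameters and passed back to func. The model layout is an arbitrary example, put in eval mode so the buffers are only read.

```python
import torch
from functorch import make_functional_with_buffers

model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.BatchNorm1d(4))
model.eval()  # use running statistics instead of batch statistics

func, params, buffers = make_functional_with_buffers(model)
# params: Linear weight/bias and BatchNorm weight/bias (4 tensors)
# buffers: running_mean, running_var, num_batches_tracked (3 tensors)

x = torch.randn(2, 3)
out = func(params, buffers, x)
```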

combine_state_for_ensemble

Prepares a list of torch.nn.Modules for ensembling with vmap().
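A sketch of ensembling: the state of several same-architecture models is stacked along a new leading dimension, and vmap evaluates all of them at once. The model count, layer sizes, and batch size are arbitrary; in_dims=(0, 0, None) maps over the stacked params and buffers while sharing the same input batch across every model.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 4
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]

# Each returned param/buffer tensor gains a leading dimension of size num_models.
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(8, 3)
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
# out[i] corresponds to models[i](x); shape (num_models, 8, 2)
```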

If you’re looking for information on fixing Batch Norm modules, please follow the functorch guidance on patching Batch Norm.
