functorch API Reference

Function Transforms

grad

The grad operator helps compute gradients of func with respect to the input(s) specified by argnums.
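
For instance, differentiating torch.sin at a point recovers the cosine. A minimal sketch (the scalar input and torch.sin are illustrative choices):

import torch
from functorch import grad

# grad(f) returns a new function that computes df/dx at the given input.
x = torch.randn([])
cos_x = grad(torch.sin)(x)
assert torch.allclose(cos_x, torch.cos(x))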

grad_and_value

Returns a function to compute a tuple of the gradient and primal, or forward, computation.
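
A brief sketch of the calling convention, with the gradient returned first and the primal value second (torch.sin is illustrative):

import torch
from functorch import grad_and_value

x = torch.randn([])
# One call yields both the gradient and the forward (primal) result.
grad_x, value = grad_and_value(torch.sin)(x)
assert torch.allclose(grad_x, torch.cos(x))
assert torch.allclose(value, torch.sin(x))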

jacrev

Computes the Jacobian of f with respect to the arg(s) at index argnums using reverse-mode autodiff.
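
For an elementwise function the Jacobian is diagonal, which makes for a quick check. A small sketch (torch.sin is an illustrative choice):

import torch
from functorch import jacrev

x = torch.randn(5)
# The Jacobian of elementwise sin is a diagonal matrix of cosines.
jacobian = jacrev(torch.sin)(x)
assert torch.allclose(jacobian, torch.diag(torch.cos(x)))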

vmap

vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs.
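
For example, mapping torch.dot over a leading batch dimension computes per-sample dot products without an explicit loop (the shapes here are illustrative):

import torch
from functorch import vmap

x = torch.randn(64, 3)
y = torch.randn(64, 3)
# Maps torch.dot over dimension 0 of both inputs.
batched_dot = vmap(torch.dot)
out = batched_dot(x, y)
assert out.shape == (64,)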

vjp

Standing for the vector-Jacobian product, vjp returns a tuple containing the result of f applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of f with respect to primals multiplied by the cotangents.
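
A short sketch of the two-step usage (torch.sin and the all-ones cotangent are illustrative):

import torch
from functorch import vjp

x = torch.randn(3)
# The first call runs the forward pass and returns a vjp function.
out, vjp_fn = vjp(torch.sin, x)
# vjp_fn returns one gradient per primal.
(grad_x,) = vjp_fn(torch.ones(3))
assert torch.allclose(grad_x, torch.cos(x))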

Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a torch.nn.Module. For example, the following computes the Jacobian of a function that takes three values and returns three values:

import torch
from functorch import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
# Jacobian of the model output with respect to the input x.
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

However, if you want to compute, say, a Jacobian with respect to the parameters of the model, then there needs to be a way to construct a function where the parameters are inputs to the function. That’s what make_functional() and make_functional_with_buffers() are for: given a torch.nn.Module, these return a new function that accepts the parameters and the inputs to the Module’s forward pass.
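
Continuing the linear-model example above, a sketch of computing a Jacobian with respect to the parameters (the model and shapes are illustrative):

import torch
from functorch import make_functional, jacrev

model = torch.nn.Linear(3, 3)
fmodel, params = make_functional(model)

x = torch.randn(3)
# params is now an explicit argument, so we can differentiate with respect to it.
jacobians = jacrev(fmodel)(params, x)
# One Jacobian per parameter tensor (weight and bias for a Linear layer).
assert jacobians[0].shape == (3, 3, 3)  # d(output) / d(weight)
assert jacobians[1].shape == (3, 3)     # d(output) / d(bias)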

make_functional

Given a torch.nn.Module, make_functional() extracts the state (params) and returns a functional version of the model, func.
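
A minimal sketch of the calling convention, with params passed as an explicit first argument (the Linear model is illustrative):

import torch
from functorch import make_functional

model = torch.nn.Linear(3, 3)
fmodel, params = make_functional(model)

x = torch.randn(3)
# The functional version takes the parameters explicitly and matches the original module.
out = fmodel(params, x)
assert torch.allclose(out, model(x))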

make_functional_with_buffers

Given a torch.nn.Module, make_functional_with_buffers() extracts the state (params and buffers) and returns a functional version of the model, func, that can be invoked like a function.
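
A sketch of the calling convention when the module carries buffers; BatchNorm1d is an illustrative choice because it registers running statistics as buffers:

import torch
from functorch import make_functional_with_buffers

model = torch.nn.BatchNorm1d(3)
fmodel, params, buffers = make_functional_with_buffers(model)

x = torch.randn(4, 3)
# Both parameters and buffers are passed explicitly.
out = fmodel(params, buffers, x)
assert out.shape == (4, 3)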

combine_state_for_ensemble

Prepares a list of torch.nn.Modules for ensembling with vmap().
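
A sketch of evaluating an ensemble in a single vmap call; the number of models and the tensor shapes are illustrative:

import torch
from functorch import combine_state_for_ensemble, vmap

models = [torch.nn.Linear(3, 3) for _ in range(5)]
# Stacks the parameters (and buffers) of the models along a new leading dimension.
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(5, 3)
# vmap maps over the stacked parameters, buffers, and the per-model inputs.
out = vmap(fmodel)(params, buffers, x)
assert out.shape == (5, 3)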