Migrating from functorch to torch.func

torch.func, previously known as “functorch”, provides JAX-like composable function transforms for PyTorch.

functorch started as an out-of-tree library over at the pytorch/functorch repository. Our goal has always been to upstream functorch directly into PyTorch and provide it as a core PyTorch library.

As the final step of upstreaming, we’ve decided to move from a top-level package (functorch) into PyTorch itself, to reflect how the function transforms are integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating import functorch and ask that users migrate to the newest APIs, which we will maintain going forward. import functorch will be kept around for a couple of releases to maintain backwards compatibility.

function transforms

The following APIs are drop-in replacements for the corresponding functorch APIs. They are fully backwards compatible.

functorch API | PyTorch API (as of PyTorch 2.0)
functorch.vmap | torch.vmap() or torch.func.vmap()
functorch.grad | torch.func.grad()
functorch.vjp | torch.func.vjp()
functorch.jvp | torch.func.jvp()
functorch.jacrev | torch.func.jacrev()
functorch.jacfwd | torch.func.jacfwd()
functorch.hessian | torch.func.hessian()
functorch.functionalize | torch.func.functionalize()
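
As a minimal sketch of how these transforms compose, the snippet below computes per-sample gradients by wrapping torch.func.grad() in torch.func.vmap(). The function loss_fn and the tensors weight, xs, and ys are illustrative names invented for this example, not part of either library.

```python
import torch
from torch.func import grad, vmap

def loss_fn(weight, x, y):
    # Squared error of a single example under a linear model (hypothetical).
    return (x @ weight - y) ** 2

torch.manual_seed(0)
weight = torch.randn(3)
xs = torch.randn(8, 3)
ys = torch.randn(8)

# grad differentiates with respect to the first argument (weight) by default.
# vmap maps over dimension 0 of xs and ys while sharing weight across the batch.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weight, xs, ys)

# Reference result computed one example at a time in a Python loop.
expected = torch.stack([grad(loss_fn)(weight, x, y) for x, y in zip(xs, ys)])
```

Code written against functorch.grad and functorch.vmap should behave the same way after only the import is changed.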

Furthermore, if you are using torch.autograd.functional APIs, please try out the torch.func equivalents instead. torch.func function transforms are more composable and more performant in many cases.

torch.autograd.functional API | torch.func API (as of PyTorch 2.0)
torch.autograd.functional.vjp() | torch.func.grad() or torch.func.vjp()
torch.autograd.functional.jvp() | torch.func.jvp()
torch.autograd.functional.jacobian() | torch.func.jacrev() or torch.func.jacfwd()
torch.autograd.functional.hessian() | torch.func.hessian()
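
Unlike the functorch table above, these pairs are not drop-in replacements, because the calling conventions differ. The sketch below compares both APIs on a small function; f, x, and v are illustrative names chosen for this example.

```python
import torch
from torch.func import jacfwd, jacrev, vjp

def f(x):
    # Elementwise sine scaled by the sum of the input: maps R^4 to R^4.
    return x.sin() * x.sum()

torch.manual_seed(0)
x = torch.randn(4)
v = torch.randn(4)

# torch.autograd.functional takes the function and its inputs in one call.
jac_old = torch.autograd.functional.jacobian(f, x)

# torch.func transforms return a new function, which is then called on the inputs.
jac_rev = jacrev(f)(x)
jac_fwd = jacfwd(f)(x)

# The old vjp takes the cotangent v directly; the new one returns a
# function that is applied to v and yields one result per primal input.
out_old, vjp_old = torch.autograd.functional.vjp(f, x, v)
out_new, vjp_fn = vjp(f, x)
(vjp_new,) = vjp_fn(v)
```

Because each torch.func transform returns an ordinary function, the result can be passed to another transform such as torch.vmap().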

NN module utilities

We’ve changed the APIs to apply function transforms over NN modules to make them fit better into the PyTorch design philosophy. The new API is different, so please read this section carefully.

functorch.make_functional

torch.func.functional_call() is the replacement for functorch.make_functional and functorch.make_functional_with_buffers. However, it is not a drop-in replacement.

If you’re in a hurry, you can use helper functions in this gist that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers. We recommend using torch.func.functional_call() directly because it is a more explicit and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters. The functional module accepts parameters and inputs to the model as arguments. torch.func.functional_call() allows one to call the forward pass of an existing module using new parameters, buffers, and inputs.
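
A minimal sketch of that behavior: the module below is called once with a dictionary of replacement parameters and once normally. The names zero_params, out_zero, and out_orig are illustrative.

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)
model = torch.nn.Linear(3, 2)
x = torch.randn(5, 3)

# Replacement parameters, keyed by the names from model.named_parameters().
zero_params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

# Forward pass using the replacement parameters instead of the module's own.
out_zero = functional_call(model, zero_params, (x,))

# The module's own parameters are untouched, so a normal call still uses them.
out_orig = model(x)
```

The module itself is not modified by the call; the replacement tensors are used only for the duration of that forward pass.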

Here’s an example of how to compute gradients of parameters of a model using functorch vs torch.func:

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here’s an example of how to compute jacobians of model parameters:

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, to keep memory consumption down, you should carry around only a single copy of your parameters. model.named_parameters() does not copy the parameters. If you update the parameters of the model in-place during training, then the nn.Module that is your model holds the single copy of the parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update them out-of-place, then there are two copies of parameters: the one in the dictionary and the one in the model. In this case, you should change model to not hold memory by converting it to the meta device via model.to('meta').
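
The out-of-place pattern might look like the sketch below. It assumes plain gradient descent with a hypothetical learning_rate, and it takes detached views of the parameters before moving the module to the meta device so that the dictionary keeps the real storage.

```python
import torch
from torch.func import functional_call, grad

torch.manual_seed(0)
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

# Detached views share storage with the module's parameters, so nothing is copied.
params = {name: p.detach() for name, p in model.named_parameters()}

# Moving the module to the meta device releases its own parameter storage.
# The dictionary now holds the only real copy.
model.to('meta')

def compute_loss(params, inputs, targets):
    prediction = functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

learning_rate = 0.1  # hypothetical value chosen for this sketch
initial_loss = compute_loss(params, inputs, targets)
for _ in range(10):
    grads = grad(compute_loss)(params, inputs, targets)
    # Out-of-place update: build a new dictionary rather than mutating tensors.
    params = {name: params[name] - learning_rate * grads[name] for name in params}
final_loss = compute_loss(params, inputs, targets)
```

After model.to('meta'), calling model(inputs) directly no longer produces real values, so every forward pass has to go through torch.func.functional_call() with the dictionary.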

functorch.combine_state_for_ensemble

Please use torch.func.stack_module_state() instead of functorch.combine_state_for_ensemble. torch.func.stack_module_state() returns two dictionaries, one of stacked parameters and one of stacked buffers, that can then be used with torch.vmap() and torch.func.functional_call() for ensembling.
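
To make the return value concrete, the sketch below stacks the state of three BatchNorm1d modules, chosen here only because they have both parameters and buffers. Each entry gains a new leading dimension of size num_models.

```python
import torch
from torch.func import stack_module_state

num_models = 3
num_features = 4
models = [torch.nn.BatchNorm1d(num_features) for _ in range(num_models)]

params, buffers = stack_module_state(models)

# Each dictionary maps a state name to one tensor stacked across the models.
param_shapes = {name: tuple(p.shape) for name, p in params.items()}
buffer_shapes = {name: tuple(b.shape) for name, b in buffers.items()}
```

The new leading dimension is the one that torch.vmap() maps over when in_dims is 0 for params and buffers.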

For example, here is how to ensemble over a very simple model:

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile

We are no longer supporting functorch.compile (also known as AOTAutograd) as a frontend for compilation in PyTorch; we have integrated AOTAutograd into PyTorch’s compilation story. If you are a user, please use torch.compile() instead.
