Whirlwind Tour
functorch provides JAX-like composable function transforms for PyTorch. In this whirlwind tour, we’ll introduce all the functorch transforms.
Why composable function transforms?
There are a number of use cases that are tricky to do in PyTorch today:
computing per-sample-gradients (or other per-sample quantities)
running ensembles of models on a single machine
efficiently batching together tasks in the inner-loop of MAML
efficiently computing Jacobians and Hessians
efficiently computing batched Jacobians and Hessians
Composing the vmap, grad, vjp, and jvp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
What are the transforms?
Right now, we support the following transforms:
grad, vjp, jvp, jacrev, jacfwd, hessian
vmap
Furthermore, we have some utilities for working with PyTorch modules.
make_functional(model)
make_functional_with_buffers(model)
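As a rough illustration of what these utilities do, the sketch below splits a small module into a stateless callable and a tuple of its parameters. It assumes a functorch release that still ships make_functional; the Linear layer and its sizes are arbitrary choices for the example.

```python
import torch
from functorch import make_functional

# A tiny module used only for illustration.
model = torch.nn.Linear(3, 2)

# Split the module into a stateless callable and a tuple of its parameters.
fmodel, params = make_functional(model)

x = torch.randn(4, 3)

# The functional version takes the parameters explicitly as its first argument.
out_functional = fmodel(params, x)

# Reference: the ordinary stateful call on the module.
out_module = model(x)
```

Because the parameters are now an explicit argument, they can be passed through transforms such as grad or vmap like any other input.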
vmap
Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.
vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.
vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:
import torch
from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()
examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
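To make the "lifting" concrete, the sketch below compares vmap against an explicit Python loop over the batch, and shows the in_dims argument selecting a batch dimension other than 0. The transposed input examples_t is introduced purely for illustration.

```python
import torch
from functorch import vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Written for a single example: a 1-D feature vector.
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)

# Batched result from vmap (maps over dimension 0 by default).
batched = vmap(model)(examples)

# Reference: apply the per-example function in a Python loop and stack.
looped = torch.stack([model(e) for e in examples])

# in_dims selects which dimension to map over; here the batch is dimension 1.
examples_t = examples.t()  # shape (feature_size, batch_size)
batched_t = vmap(model, in_dims=1)(examples_t)
```

All three results should agree up to floating-point tolerance; vmap simply avoids the Python-level loop.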
grad
grad(func)(*inputs) assumes func returns a single-element Tensor. By default, it computes the gradients of the output of func w.r.t. inputs[0].
from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())
# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
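Differentiating with respect to an argument other than the first is done through grad's argnums parameter. A minimal sketch, using a two-argument function f chosen only for illustration:

```python
import torch
from functorch import grad

def f(x, y):
    # Scalar-valued function of two scalar tensors.
    return x * torch.sin(y)

x = torch.tensor(2.0)
y = torch.tensor(0.5)

# Default: differentiate with respect to the first argument.
df_dx = grad(f)(x, y)

# argnums picks a different argument ...
df_dy = grad(f, argnums=1)(x, y)

# ... or several at once, in which case a tuple of gradients is returned.
both = grad(f, argnums=(0, 1))(x, y)
```

Analytically, the derivative with respect to x is sin(y) and the derivative with respect to y is x * cos(y).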
When composed with vmap, grad can be used to compute per-sample-gradients:
from functorch import vmap
batch_size, feature_size = 3, 5
def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()
def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss
weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
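One way to sanity-check per-sample gradients is to compare the vmap(grad(...)) result against one grad call per example. The sketch below does that under the same setup as above, with the model folded into compute_loss for brevity.

```python
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5

def compute_loss(weights, example, target):
    # Per-example squared error of a linear model with ReLU activation.
    y = example.dot(weights).relu()
    return (y - target) ** 2

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# weights is shared (in_dims=None); examples and targets are mapped over dim 0.
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
    weights, examples, targets
)

# Reference: one grad call per example, stacked along a new batch dimension.
reference = torch.stack([
    grad(compute_loss)(weights, examples[i], targets[i])
    for i in range(batch_size)
])
```

The result has one gradient row per example, with the same trailing shape as weights.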
vjp
The vjp transform applies func to inputs and returns a new function that computes vjps given some cotangent Tensors.
from functorch import vjp
inputs = torch.randn(3)
func = torch.sin
cotangents = (torch.randn(3),)
outputs, vjp_fn = vjp(func, inputs)
vjps = vjp_fn(*cotangents)
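For an elementwise function such as sin, the Jacobian is diagonal, so the vjp can be checked by hand: it should equal the cotangent multiplied elementwise by cos(x). A minimal sketch of that check:

```python
import torch
from functorch import vjp

x = torch.randn(3)
cotangent = torch.randn(3)

# vjp returns the function output and a function that computes the vjp.
outputs, vjp_fn = vjp(torch.sin, x)

# vjp_fn returns one entry per input of the original function.
(vjp_x,) = vjp_fn(cotangent)

# sin acts elementwise, so its Jacobian is diag(cos(x)) and
# the vector-Jacobian product reduces to cotangent * cos(x).
expected = cotangent * torch.cos(x)
```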
jvp
The jvp transform computes Jacobian-vector products and is also known as “forward-mode AD”. Unlike most other transforms, it is not a higher-order function: it directly returns the outputs of func(inputs) as well as the jvps.
from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
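A jvp should match the full Jacobian multiplied by the tangent vector. The sketch below checks this for a small single-argument function f (chosen for illustration), using jacrev to build the reference Jacobian.

```python
import torch
from functorch import jacrev, jvp

def f(x):
    # Elementwise function from R^5 to R^5.
    return x.sin() * x

x = torch.randn(5)
v = torch.randn(5)  # tangent vector (direction of differentiation)

# Primals and tangents are passed as tuples, one entry per argument of f.
output, jvp_out = jvp(f, (x,), (v,))

# Reference: materialize the full Jacobian, then multiply by the tangent.
reference = jacrev(f)(x) @ v
```

The jvp call never builds the 5x5 Jacobian, which is the point of using it when only a directional derivative is needed.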
jacrev, jacfwd, and hessian
The jacrev transform returns a new function that takes in x and returns the Jacobian of the function with respect to x using reverse-mode AD.
from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
jacrev can be composed with vmap to produce batched Jacobians:
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
jacfwd is a drop-in replacement for jacrev that computes Jacobians using forward-mode AD:
from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
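To illustrate the drop-in claim on a function whose Jacobian is not square, the sketch below compares jacrev and jacfwd on a map from R^5 to R^3. The matrix A and the function f are assumptions made for the example.

```python
import torch
from functorch import jacfwd, jacrev

A = torch.randn(3, 5)

def f(x):
    # Maps a vector in R^5 to a vector in R^3.
    return torch.tanh(A @ x)

x = torch.randn(5)

jac_rev = jacrev(f)(x)  # reverse-mode AD
jac_fwd = jacfwd(f)(x)  # forward-mode AD

# Closed form: diag(1 - tanh(Ax)^2) @ A, written with broadcasting.
expected = (1 - torch.tanh(A @ x) ** 2).unsqueeze(1) * A
```

Both transforms return the same (3, 5) Jacobian. As a general rule of thumb, forward mode tends to be cheaper when a function has fewer inputs than outputs, and reverse mode when it has fewer outputs than inputs.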
Composing jacrev with itself or with jacfwd can produce Hessians:
def f(x):
    return x.sin().sum()
x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
hessian is a convenience function that combines jacfwd and jacrev:
from functorch import hessian
def f(x):
    return x.sin().sum()
x = torch.randn(5)
hess = hessian(f)(x)
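For this particular f the Hessian has a simple closed form: the second derivative of sum(sin(x)) is -sin(x_i) on the diagonal and zero elsewhere. The sketch below checks hessian against that expression and against the manual jacfwd(jacrev(f)) composition.

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)

hess = hessian(f)(x)

# Manual composition of forward-over-reverse Jacobians.
hess_manual = jacfwd(jacrev(f))(x)

# Closed form: -sin(x_i) on the diagonal, zeros off the diagonal.
expected = torch.diag(-torch.sin(x))
```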
Conclusion
Check out our other tutorials (in the left bar) for more detailed explanations of how to apply functorch transforms for various use cases. functorch is very much a work in progress and we’d love to hear how you’re using it. We encourage you to start a conversation on our issue tracker to discuss your use case.