functorch, like JAX, has restrictions around what can be transformed. In general, JAX’s limitations are that transforms only work with pure functions: that is, functions where the output is completely determined by the input and that do not involve side effects (like mutation).
We have a similar guarantee: our transforms work well with pure functions. However, we do support certain in-place operations. On one hand, writing code compatible with functorch transforms may involve changing how you write PyTorch code; on the other hand, you may find that our transforms let you express things that were previously difficult to express in PyTorch.
All functorch transforms share a limitation in that a function should not assign to global variables. Instead, all outputs of a function must be returned from the function. This restriction comes from how functorch is implemented: each transform wraps Tensor inputs in special functorch Tensor subclasses that facilitate the transform.
So, instead of the following:
```python
import torch
from functorch import grad

# Don't do this
intermediate = None

def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z

x = torch.randn([])
grad_x = grad(f)(x)
```
Please rewrite f to return intermediate:

```python
def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)
```
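As a quick sanity check of the has_aux pattern, the sketch below confirms that grad() differentiates only the first returned value (z) and hands back the second (intermediate) unchanged. The expected gradient comes from the chain rule: d/dx sin(sin(x)) = cos(sin(x)) * cos(x).

```python
import torch
from functorch import grad

def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    # With has_aux=True, f returns (value_to_differentiate, auxiliary_output).
    return z, intermediate

x = torch.randn([])
grad_x, intermediate = grad(f, has_aux=True)(x)

# Chain rule: d/dx sin(sin(x)) = cos(sin(x)) * cos(x)
expected_grad = torch.cos(torch.sin(x)) * torch.cos(x)
assert torch.allclose(grad_x, expected_grad)

# The auxiliary output is returned as-is, not differentiated.
assert torch.allclose(intermediate, x.sin())
```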
vmap() is our most restrictive transform. The grad-related transforms (grad(), vjp(), jvp()) do not have these limitations. jacfwd() (and hessian(), which is implemented with jacfwd()) is a composition of vmap() and jvp(), so it also has these limitations.
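To see why jacfwd() inherits vmap's limitations, here is a minimal hand-built forward-mode Jacobian. It is a sketch of the idea, not functorch's actual implementation. Each standard basis vector e_j is pushed through jvp(), which yields J @ e_j (column j of the Jacobian), and vmap() batches all of those pushes into one call. The function f and the helper jvp_column are hypothetical examples. f is chosen so that its Jacobian is not symmetric, which means the transpose step is actually tested.

```python
import torch
from functorch import jacfwd, jvp, vmap

def f(x):
    # Jacobian entries: J[i, j] = (i == j) * x.sum() + x[i]  (not symmetric)
    return x * x.sum()

x = torch.randn(4)

def jvp_column(v):
    # jvp returns (f(x), J @ v); keep only the Jacobian-vector product.
    return jvp(f, (x,), (v,))[1]

# vmap over the rows of the identity matrix: row j of `columns` is J @ e_j,
# i.e. column j of J, so transposing recovers J itself.
columns = vmap(jvp_column)(torch.eye(4))
jacobian = columns.T

assert torch.allclose(jacobian, jacfwd(f)(x))
```

Because the basis vectors travel through jvp() inside vmap(), any function that violates a vmap restriction will also fail under jacfwd().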
vmap(func) is a transform that returns a function that maps func over some new dimension of each input Tensor. The mental model for vmap is that it is like running a for-loop: for pure functions (i.e. in the absence of side effects), vmap(f)(x) is equivalent to:
```python
torch.stack([f(x_i) for x_i in x.unbind(0)])
```
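The equivalence can be checked directly for a pure function. In the sketch below, f is a hypothetical per-example function that computes a weighted sum of one row's features, using a closed-over `weights` Tensor that is not vmapped over. The vectorized call and the explicit loop should agree.

```python
import torch
from functorch import vmap

weights = torch.randn(5)

def f(x_i):
    # Pure function of a single example: weighted sum of its 5 features.
    return (x_i * weights).sum()

x = torch.randn(3, 5)

batched = vmap(f)(x)                                    # one vectorized call
looped = torch.stack([f(x_i) for x_i in x.unbind(0)])   # explicit for-loop

assert batched.shape == (3,)
assert torch.allclose(batched, looped)
```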
Mutation: Arbitrary mutation of Python data structures
In the presence of side effects, vmap() no longer acts like it is running a for-loop. For example, the following function:
```python
import torch
from functorch import vmap

def f(x, items):
    items.pop()
    print("hello!")
    return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)
```
will print “hello!” once and pop only one element from lst.

vmap() executes f a single time, so all side effects happen only once.
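The sketch below uses a hypothetical call log (`calls`) to make the "runs once" behavior observable. Under vmap the side effect happens once for the whole batch of 3 examples. The equivalent for-loop triggers it once per example, even though both approaches compute the same values.

```python
import torch
from functorch import vmap

calls = []

def f(x):
    # Python-level side effect: happens once per call of f.
    calls.append("called")
    return x.sum(0)

x = torch.randn(3, 1)
result = vmap(f)(x)

assert len(calls) == 1        # f ran a single time for all 3 examples
assert result.shape == (3,)

# A plain for-loop, by contrast, runs the side effect once per example.
calls.clear()
looped = torch.stack([f(x_i) for x_i in x.unbind(0)])
assert len(calls) == 3
assert torch.allclose(result, looped)
```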
This is a consequence of how vmap is implemented. functorch has a special, internal BatchedTensor class. vmap(f)(*inputs) takes all Tensor inputs, turns them into BatchedTensors, and calls f once with those BatchedTensors. BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) behavior for each PyTorch operator.
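One visible effect of this design is that, inside f, a BatchedTensor reports only the per-example (logical) shape. The dimension being vmapped over is hidden from f and added back to the output. The sketch below records what f sees, using a hypothetical `seen_shapes` list. Because appending to it is a side effect, it is also recorded only once.

```python
import torch
from functorch import vmap

seen_shapes = []

def f(x):
    # Inside vmap, x is a BatchedTensor; its shape excludes the vmapped dim.
    seen_shapes.append(tuple(x.shape))
    return x * 2

x = torch.randn(5, 2)
out = vmap(f)(x)

assert seen_shapes == [(2,)]   # f saw one per-example shape, exactly once
assert out.shape == (5, 2)     # the batch dimension is restored on output
```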
Mutation: In-place PyTorch operations
vmap() will raise an error if it encounters an unsupported PyTorch
in-place operation and it will succeed otherwise. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here’s an example of how this can occur:
```python
import torch
from functorch import vmap

def f(x, y):
    x.add_(y)
    return x

x = torch.randn(1)
y = torch.randn(3)

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)
```
x is a Tensor with one element and y is a Tensor with three elements. x + y has three elements (due to broadcasting), but x has room for only one, so attempting to write the result back into x raises an error.
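When the mutation is not essential, one straightforward workaround is to use the out-of-place operator, which allocates a new Tensor instead of writing into the smaller one. The sketch below assumes the unsupported in-place case surfaces as a RuntimeError, which is how PyTorch reports failed internal checks. The names f_inplace and f_out_of_place are hypothetical.

```python
import torch
from functorch import vmap

def f_inplace(x, y):
    x.add_(y)        # tries to write 3 per-example results into a 1-element x
    return x

def f_out_of_place(x, y):
    return x + y     # allocates a fresh Tensor for each example

x = torch.randn(1)
y = torch.randn(3)

try:
    vmap(f_inplace, in_dims=(None, 0))(x, y)
    raised = False
except RuntimeError:
    raised = True

result = vmap(f_out_of_place, in_dims=(None, 0))(x, y)

assert raised
assert result.shape == (3, 1)                        # one (1,)-shaped result per y_i
assert torch.allclose(result, x + y.unsqueeze(1))
```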
There is no problem if the Tensor being written to has the same number of elements (or more):
```python
def f(x, y):
    x.add_(y)
    return x

x = torch.randn(3)
y = torch.randn(3)
expected = x + y

# Does not raise an error because x and y have the same number of elements.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
```
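The "(or more)" case covers the opposite direction of the failing example. Here the Tensor being written to (x) is vmapped over, while the Tensor being added (y) is not. Each x_i already has room for x_i + y, so the in-place update succeeds and, as in the example above, mutates the original x. The shapes are illustrative choices.

```python
import torch
from functorch import vmap

def f(x, y):
    x.add_(y)
    return x

x = torch.randn(3, 2)   # vmapped over dim 0: each x_i has 2 elements
y = torch.randn(2)      # not vmapped: every x_i receives the same y
expected = x + y

vmap(f, in_dims=(0, None))(x, y)
assert torch.allclose(x, expected)
```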
Data-dependent Python control flow
We don’t yet support vmap over data-dependent control flow. Data-dependent control flow is when the condition of an if-statement, while-loop, or for-loop is a Tensor that is being vmap’ed over. For example, the following will raise an error:
```python
def relu(x):
    if x > 0:
        return x
    return 0

x = torch.randn(3)
vmap(relu)(x)
```
However, any control flow that does not depend on the values in vmap’ed tensors will work:
```python
def custom_dot(x):
    if x.dim() == 1:
        return torch.dot(x, x)
    return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)
```
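Note that a shape-based condition such as x.dim() is evaluated on the per-example shape, because the vmapped dimension is hidden inside f. With x = torch.randn(3), each example is 0-dimensional and the (x * x).sum() branch runs. With a 2-D input, each example is 1-D and the torch.dot branch runs instead. The input shape below is an illustrative choice.

```python
import torch
from functorch import vmap

def custom_dot(x):
    if x.dim() == 1:
        return torch.dot(x, x)
    return (x * x).sum()

x = torch.randn(4, 3)          # each example has shape (3,), so x.dim() == 1
result = vmap(custom_dot)(x)   # takes the torch.dot branch for every example

assert result.shape == (4,)
assert torch.allclose(result, (x * x).sum(1))
```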
JAX supports transforming over data-dependent control flow using special control flow operators (e.g. jax.lax.cond, jax.lax.while_loop). We’re investigating adding equivalents of those to functorch (open an issue on GitHub to voice your support!).
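In the meantime, simple value-dependent branches like the relu example can often be rewritten as elementwise tensor operations, so that no Python if ever inspects a vmapped value. The sketch below uses torch.where. This is a general PyTorch technique, not a functorch control-flow operator.

```python
import torch
from functorch import vmap

def relu(x):
    # torch.where selects elementwise between two tensors, so the branch is
    # expressed as a tensor op instead of a Python `if` on a Tensor value.
    return torch.where(x > 0, x, torch.zeros_like(x))

x = torch.randn(3)
result = vmap(relu)(x)

assert torch.allclose(result, torch.relu(x))
```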