UX Limitations
functorch, like JAX, has restrictions around what can be transformed. In general, JAX’s limitations are that transforms only work with pure functions: that is, functions where the output is completely determined by the input and that do not involve side effects (like mutation).
We have a similar guarantee: our transforms work well with pure functions. However, we do support certain in-place operations. On one hand, writing code compatible with functorch transforms may involve changing how you write PyTorch code; on the other hand, you may find that our transforms let you express things that were previously difficult to express in PyTorch.
General limitations
All functorch transforms share one limitation: a function should not assign to global variables. Instead, all outputs of a function must be returned from the function. This restriction comes from how functorch is implemented: each transform wraps Tensor inputs in special functorch Tensor subclasses that facilitate the transform.
So, instead of the following:
import torch
from functorch import grad
# Don't do this
intermediate = None
def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z
x = torch.randn([])
grad_x = grad(f)(x)
Please rewrite f to return intermediate:
def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    # The first output is differentiated; the second is returned as auxiliary data
    return z, intermediate
grad_x, intermediate = grad(f, has_aux=True)(x)
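As a quick sanity check, the gradient returned with has_aux=True should match the analytic derivative of sin(sin(x)), which is cos(sin(x)) * cos(x). The auxiliary output should come back unchanged as x.sin(). This is a minimal sketch that assumes a PyTorch install that still ships the functorch package; the names expected_grad and the fixed seed are only illustrative.

```python
import torch
from functorch import grad

torch.manual_seed(0)

def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    # grad differentiates z only; intermediate is passed through as aux output
    return z, intermediate

x = torch.randn([])
grad_x, intermediate = grad(f, has_aux=True)(x)

# d/dx sin(sin(x)) = cos(sin(x)) * cos(x)
expected_grad = torch.cos(torch.sin(x)) * torch.cos(x)
print(torch.allclose(grad_x, expected_grad), torch.allclose(intermediate, x.sin()))
```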
vmap limitations
Note
vmap() is our most restrictive transform. The grad-related transforms (grad(), vjp(), jvp()) do not have these limitations. jacfwd() (and hessian(), which is implemented with jacfwd()) is a composition of vmap() and jvp(), so it also has these limitations.
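Because hessian() is built on jacfwd(), anything that trips up vmap() can also surface through hessian(). The sketch below only checks the composition numerically. It assumes hessian(f) is equivalent to jacfwd(jacrev(f)), as the functorch API reference describes. The test function sum(x_i * sin(x_i)) is chosen arbitrarily because its Hessian is easy to write by hand: it is diagonal with entries 2*cos(x_i) - x_i*sin(x_i).

```python
import torch
from functorch import hessian, jacfwd, jacrev

torch.manual_seed(0)

def f(x):
    # A pure, scalar-valued function of a 1-D input
    return (x.sin() * x).sum()

x = torch.randn(3)
h = hessian(f)(x)

# hessian should agree with forward-over-reverse composition
h_composed = jacfwd(jacrev(f))(x)
# Analytic Hessian: diagonal with 2*cos(x) - x*sin(x)
h_analytic = torch.diag(2 * x.cos() - x * x.sin())
print(h.shape)
```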
vmap(func) is a transform that returns a function that maps func over some new dimension of each input Tensor. The mental model for vmap is that it is like running a for-loop: for pure functions (i.e. in the absence of side effects), vmap(f)(x) is equivalent to:
torch.stack([f(x_i) for x_i in x.unbind(0)])
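For a pure function, this equivalence can be checked directly. The per-example function below is an arbitrary illustration. It receives one row of x at a time, and the batched and looped results should match.

```python
import torch
from functorch import vmap

torch.manual_seed(0)

def f(x_i):
    # Pure per-example function: output depends only on x_i, no side effects
    return x_i.sin() * x_i.sum()

x = torch.randn(4, 3)
batched = vmap(f)(x)
looped = torch.stack([f(x_i) for x_i in x.unbind(0)])
print(batched.shape)
```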
Mutation: Arbitrary mutation of Python data structures
In the presence of side effects, vmap() no longer acts like it is running a for-loop. For example, the following function:
import torch
from functorch import vmap
# `lst` avoids shadowing the built-in `list`
def f(x, lst):
    lst.pop()
    print("hello!")
    return x.sum(0)
x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)
will print “hello!” once and pop only one element from lst, even though x holds three examples. vmap() executes f a single time, so all side effects happen only once.
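One way to see this is to count how often the Python body of a function actually runs. In the sketch below, the calls list is a hypothetical helper used only for counting. The batch has three examples, yet vmap() runs the body once, while the equivalent for-loop runs it three times.

```python
import torch
from functorch import vmap

torch.manual_seed(0)
calls = []  # records each time the Python body of f runs

def f(x):
    calls.append(1)  # a Python-level side effect
    return x.sum(0)

x = torch.randn(3, 1)

vmap(f)(x)
vmap_calls = len(calls)

calls.clear()
torch.stack([f(x_i) for x_i in x.unbind(0)])
loop_calls = len(calls)
print(vmap_calls, loop_calls)
```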
This is a consequence of how vmap is implemented. functorch has a special, internal BatchedTensor class. vmap(f)(*inputs) takes all Tensor inputs, turns them into BatchedTensors, and calls f(*batched_tensor_inputs). BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) behavior for each PyTorch operator.
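BatchedTensor is internal, so the sketch below does not touch it directly. It only observes one visible effect: inside f, the vmapped dimension is hidden, and operations see the per-example shape. The seen_shapes list is a hypothetical helper that records what f observes.

```python
import torch
from functorch import vmap

torch.manual_seed(0)
seen_shapes = []  # shapes observed inside f

def f(x):
    # Inside vmap, x reports its per-example shape; the batch dim is hidden
    seen_shapes.append(tuple(x.shape))
    return x * 2

x = torch.randn(5, 3)
out = vmap(f)(x)
print(seen_shapes, out.shape)
```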
Mutation: in-place PyTorch Operations
vmap() will raise an error if it encounters an unsupported PyTorch in-place operation; otherwise, it will succeed. Unsupported operations are those that would cause a Tensor with more memory to be written to a Tensor with less memory. Here’s an example of how this can occur:
def f(x, y):
    x.add_(y)
    return x
x = torch.randn(1)
y = torch.randn(3)
# Raises an error
vmap(f, in_dims=(None, 0))(x, y)
x is a Tensor with one element, and y is a Tensor with three elements. x + y has three elements (due to broadcasting), but x only has room for one element, so writing the three-element result back into x in place raises an error.
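A common workaround is to switch to the out-of-place operator, which allocates a new result tensor instead of writing into x. The sketch below catches the error from the failing in-place version and checks that the out-of-place version matches a plain for-loop. It assumes the error surfaces as a RuntimeError. The function names f_inplace and f_out_of_place are illustrative.

```python
import torch
from functorch import vmap

torch.manual_seed(0)

def f_inplace(x, y):
    x.add_(y)  # tries to write a batched result into unbatched x
    return x

def f_out_of_place(x, y):
    return x + y  # allocates a new tensor, so no memory constraint on x

x = torch.randn(1)
y = torch.randn(3)

raised = False
try:
    vmap(f_inplace, in_dims=(None, 0))(x, y)
except RuntimeError:
    raised = True

out = vmap(f_out_of_place, in_dims=(None, 0))(x, y)
# Equivalent for-loop: each y_i (a scalar) is added to the one-element x
expected = torch.stack([x + y_i for y_i in y.unbind(0)])
print(raised, out.shape)
```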
There is no problem if there is sufficient memory for the in-place operations to occur:
def f(x, y):
    x.add_(y)
    return x
x = torch.randn(3)
y = torch.randn(3)
expected = x + y
# Does not raise an error
vmap(f, in_dims=(0, 0))(x, y)
# x was mutated in place
assert torch.allclose(x, expected)
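The rule only concerns which side has more memory. Writing an unbatched tensor into a batched one is fine, because the batched target has room for every example. In the sketch below, y is shared across the batch (in_dims=None) and is added in place to each row of x. The shapes are illustrative.

```python
import torch
from functorch import vmap

torch.manual_seed(0)

def f(x, y):
    x.add_(y)  # x is batched (more memory), y is not (less memory): allowed
    return x

x = torch.randn(3, 2)  # vmapped over dim 0, so each example has shape (2,)
y = torch.randn(2)     # not vmapped (in_dims=None)
expected = x + y       # broadcasting adds y to every row

out = vmap(f, in_dims=(0, None))(x, y)
print(out.shape)
```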
Data-dependent Python control flow
We don’t yet support vmap over data-dependent control flow. Data-dependent control flow is when the condition of an if-statement, while-loop, or for-loop is a Tensor that is being vmap’ed over. For example, the following will raise an error message:
def relu(x):
    if x > 0:
        return x
    return 0
x = torch.randn(3)
vmap(relu)(x)
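One workaround is to express the branch as an elementwise selection with torch.where. It evaluates both sides and picks per element, so no Python-level decision depends on a vmapped value. This technique is not described above, and relu_where is a hypothetical name. The sketch assumes the unsupported case surfaces as a RuntimeError.

```python
import torch
from functorch import vmap

torch.manual_seed(0)

def relu(x):
    if x > 0:  # data-dependent: the condition reads a vmapped value
        return x
    return 0

def relu_where(x):
    # Elementwise select instead of a Python branch; both sides are computed
    return torch.where(x > 0, x, torch.zeros_like(x))

x = torch.randn(3)

raised = False
try:
    vmap(relu)(x)
except RuntimeError:
    raised = True

out = vmap(relu_where)(x)
print(raised, out)
```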
However, any control flow that is not dependent on the values in vmap’ed tensors will work:
def custom_dot(x):
    if x.dim() == 1:
        return torch.dot(x, x)
    return (x * x).sum()
x = torch.randn(3)
vmap(custom_dot)(x)
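The condition x.dim() == 1 depends only on shape, which is fixed for the whole batch. Inside vmap, each example of a 1-D input is 0-dimensional, so the second branch runs exactly once, and each output is that element squared. The branches list below is a hypothetical helper for recording which path was taken.

```python
import torch
from functorch import vmap

torch.manual_seed(0)
branches = []  # records which branch executes

def custom_dot(x):
    if x.dim() == 1:  # shape-dependent, not value-dependent: fine under vmap
        branches.append("dot")
        return torch.dot(x, x)
    branches.append("sum")
    return (x * x).sum()

x = torch.randn(3)
out = vmap(custom_dot)(x)  # each example is a 0-d tensor inside vmap
print(branches, out)
```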
JAX supports transforming over data-dependent control flow using special control flow operators (e.g. jax.lax.cond, jax.lax.while_loop). We’re investigating adding equivalents of those to functorch (open an issue on GitHub to voice your support!).