# UX Limitations

functorch, like JAX, has restrictions around what can be transformed. In general, JAX’s limitations are that transforms only work with pure functions: that is, functions where the output is completely determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions. However, we do support certain in-place operations. On one hand, writing code compatible with functorch transforms may involve changing how you write PyTorch code; on the other hand, you may find that our transforms let you express things that were previously difficult to express in PyTorch.

## General limitations

All functorch transforms share a limitation in that a function should not assign to global variables. Instead, all outputs of a function must be returned from the function. This restriction comes from how functorch is implemented: each transform wraps Tensor inputs in special functorch Tensor subclasses that facilitate the transform.

So, instead of the following:

```
import torch
from functorch import grad

# Don't do this
intermediate = None

def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z

x = torch.randn([])
grad_x = grad(f)(x)
```

Please rewrite `f` to return `intermediate`:

```
def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)
```
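
Here, `has_aux=True` tells `grad()` that `f` returns a `(output, auxiliary)` pair: the gradient is computed with respect to the first element only, and the auxiliary value is passed back to you unchanged.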

## vmap limitations

Note

`vmap()` is our most restrictive transform. The grad-related transforms (`grad()`, `vjp()`, `jvp()`) do not have these limitations. `jacfwd()` (and `hessian()`, which is implemented with `jacfwd()`) is a composition of `vmap()` and `jvp()`, so it also has these limitations.

`vmap(func)` is a transform that returns a function that maps `func` over some new dimension of each input Tensor. The mental model for vmap is that it is like running a for-loop: for pure functions (i.e. in the absence of side effects), `vmap(f)(x)` is equivalent to:

```
torch.stack([f(x_i) for x_i in x.unbind(0)])
```
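
For a pure function you can check this equivalence directly. The sketch below uses a hypothetical per-example function `f`; it assumes only `torch` and `functorch` are available:

```
import torch
from functorch import vmap

def f(x_i):
    # A pure, per-example function (hypothetical example)
    return x_i.sin() + 1.0

x = torch.randn(3, 4)
looped = torch.stack([f(x_i) for x_i in x.unbind(0)])
vmapped = vmap(f)(x)
assert torch.allclose(looped, vmapped)
```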

### Mutation: Arbitrary mutation of Python data structures

In the presence of side effects, `vmap()` no longer acts like it is running a for-loop. For example, the following function:

```
def f(x, list):
    list.pop()
    print("hello!")
    return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)
```

will print “hello!” once and pop only one element from `lst`.

`vmap()` executes `f` a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. functorch has a special, internal BatchedTensor class. `vmap(f)(*inputs)` takes all Tensor inputs, turns them into BatchedTensors, and calls `f(*batched_tensor_inputs)`. BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) behavior for each PyTorch operator.
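
As a small sketch of what this implies, the function below sees the per-example view that the BatchedTensor presents (the batch dimension is hidden inside `f`), and its `print` runs only once rather than once per batch element:

```
import torch
from functorch import vmap

def f(x):
    # Inside vmap, x looks like a single example: shape (4,), not (3, 4)
    print(x.shape)
    return x.sum()

x = torch.randn(3, 4)
vmap(f)(x)  # prints "torch.Size([4])" exactly once
```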

### Mutation: in-place PyTorch Operations

`vmap()` will raise an error if it encounters an unsupported PyTorch in-place operation, and it will succeed otherwise. Unsupported operations are those that would cause a Tensor with more memory to be written to a Tensor with less memory. Here’s an example of how this can occur:

```
def f(x, y):
    x.add_(y)
    return x

x = torch.randn(1)
y = torch.randn(3)
# Raises an error
vmap(f, in_dims=(None, 0))(x, y)
```

`x` is a Tensor with one element, `y` is a Tensor with three elements. `x + y` has three elements (due to broadcasting), but attempting to write three elements back into `x`, which only has one element, raises an error due to there not being enough memory to hold three elements.

There is no problem if there is sufficient memory for the in-place operations to occur:

```
def f(x, y):
    x.add_(y)
    return x

x = torch.randn(3)
y = torch.randn(3)
expected = x + y
# Does not raise an error: three elements are written into a three-element Tensor
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
```
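
If you run into the error above, one option (sketched here; not the only fix) is to rewrite the function to use an out-of-place operation, so the result is a freshly allocated Tensor of the broadcasted shape rather than a write into `x`:

```
import torch
from functorch import vmap

def f(x, y):
    # Out-of-place add allocates a new Tensor, so nothing is written into x
    return x + y

x = torch.randn(1)
y = torch.randn(3)
result = vmap(f, in_dims=(None, 0))(x, y)  # works; result has shape (3, 1)
```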

### Data-dependent Python control flow

We don’t yet support `vmap` over data-dependent control flow. Data-dependent control flow is when the condition of an if-statement, while-loop, or for-loop is a Tensor that is being `vmap`’ed over. For example, the following will raise an error:

```
def relu(x):
    if x > 0:
        return x
    return 0

x = torch.randn(3)
vmap(relu)(x)
```

However, any control flow that is not dependent on the values in `vmap`’ed tensors will work:

```
def custom_dot(x):
    if x.dim() == 1:
        return torch.dot(x, x)
    return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)
```
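
If you do need a data-dependent branch like the `relu` example above, a common workaround (a sketch, not the only option) is to express it with tensor operations such as `torch.where`, which evaluate both sides elementwise and so never branch in Python:

```
import torch
from functorch import vmap

def relu(x):
    # Both branches are computed; torch.where selects per element
    return torch.where(x > 0, x, torch.zeros_like(x))

x = torch.randn(3)
vmap(relu)(x)  # works
```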

JAX supports transforming over data-dependent control flow using special control flow operators (e.g. `jax.lax.cond`, `jax.lax.while_loop`). We’re investigating adding equivalents of those to functorch (open an issue on GitHub to voice your support!).
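
For comparison, here is a minimal JAX sketch of the earlier `relu` example using `jax.lax.cond` (assuming JAX is installed); because both branches are traced rather than executed under a Python `if`, the function stays transformable under `jax.vmap`:

```
import jax
import jax.numpy as jnp

def relu(x):
    # lax.cond traces both branches, so vmap can batch over the predicate
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

x = jnp.array([-1.0, 0.5, 2.0])
jax.vmap(relu)(x)  # Array([0. , 0.5, 2. ], dtype=float32)
```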