UX Limitations

functorch, like JAX, has restrictions around what can be transformed. In general, JAX’s limitations are that transforms only work with pure functions: that is, functions where the output is completely determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions. However, we do support certain in-place operations. On one hand, writing code compatible with functorch transforms may involve changing how you write PyTorch code; on the other hand, you may find that our transforms let you express things that were previously difficult to express in PyTorch.

General limitations

All functorch transforms share a limitation: a function should not assign to global variables. Instead, all outputs of a function must be returned from the function. This restriction comes from how functorch is implemented: each transform wraps Tensor inputs in special functorch Tensor subclasses that facilitate the transform, and a value assigned to a global would let one of those wrapped Tensors escape the transform that created it.

So, instead of the following:

import torch
from functorch import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

Please rewrite f to return intermediate:

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

vmap limitations

Note

vmap() is our most restrictive transform. The grad-related transforms (grad(), vjp(), jvp()) do not have these limitations. jacfwd() (and hessian(), which is implemented with jacfwd()) is a composition of vmap() and jvp() so it also has these limitations.

vmap(func) is a transform that returns a function that maps func over some new dimension of each input Tensor. The mental model for vmap is that it is like running a for-loop: for pure functions (i.e. in the absence of side effects), vmap(f)(x) is equivalent to:

torch.stack([f(x_i) for x_i in x.unbind(0)])
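
The for-loop mental model above can be checked directly for a pure function. The following is a minimal sketch rather than part of the original text: the function f and the input shape are illustrative choices, and the import fallback assumes either PyTorch 2.x (where these transforms live under torch.func) or a standalone functorch install.

```python
import torch

try:
    from torch.func import vmap  # PyTorch 2.x location (assumption)
except ImportError:
    from functorch import vmap   # standalone functorch

def f(x_i):
    # Pure function: the output depends only on the input, no side effects.
    return x_i.sin().sum()

x = torch.randn(4, 3)

# Vectorized over dimension 0 of x
batched = vmap(f)(x)

# Explicit for-loop over the same dimension
looped = torch.stack([f(x_i) for x_i in x.unbind(0)])

same = torch.allclose(batched, looped)
```

Both results have shape (4,), one entry per slice of x along dimension 0.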

Mutation: Arbitrary mutation of Python data structures

In the presence of side effects, vmap() no longer acts like it is running a for-loop. For example, the following function:

import torch
from functorch import vmap

def f(x, lst):
  # Side effects: mutates lst and prints
  lst.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

will print “hello!” once and pop only one element from lst.

vmap() executes f a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. functorch has a special, internal BatchedTensor class. vmap(f)(*inputs) takes all Tensor inputs, turns them into BatchedTensors, and calls f(*batched_tensor_inputs). BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) behavior for each PyTorch operator.
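
The single execution described above can be observed by recording what f sees each time it is called. This is an illustrative sketch, not from the original text: the calls list and the function body are hypothetical, and the import fallback assumes either PyTorch 2.x or a standalone functorch install.

```python
import torch

try:
    from torch.func import vmap  # PyTorch 2.x location (assumption)
except ImportError:
    from functorch import vmap   # standalone functorch

calls = []  # hypothetical log of the shapes f observes

def f(x_i):
    # Side effect used only for demonstration: record the shape f sees.
    calls.append(tuple(x_i.shape))
    return x_i.sum(0)

x = torch.randn(3, 1)
out = vmap(f)(x)

num_calls = len(calls)  # f ran once, not three times
seen_shape = calls[0]   # the batch dimension is hidden from f
```

Even though x has three rows, f is entered once, and inside f the input looks like a single example of shape (1,). The batched behavior comes from the wrapped Tensor, not from repeated calls.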

Mutation: in-place PyTorch Operations

vmap() will raise an error if it encounters an unsupported PyTorch in-place operation and it will succeed otherwise. Unsupported operations are those that would cause a Tensor with more memory to be written to a Tensor with less memory. Here’s an example of how this can occur:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3)

# Raises an error
vmap(f, in_dims=(None, 0))(x, y)

x is a Tensor with one element and y is a Tensor with three elements. x + y has three elements (due to broadcasting), but writing those three elements back into x, which has room for only one, raises an error.

There is no problem if there is sufficient memory for the in-place operations to occur:

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3)
y = torch.randn(3)
expected = x + y

# Does not raise an error because x and y have the same number of elements
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
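
When the in-place form is not supported, one option is to write the function out-of-place so that nothing is written back into the smaller Tensor. The sketch below is an assumption about how one might restructure the failing example, not a recommendation from the original text; the import fallback assumes either PyTorch 2.x or a standalone functorch install.

```python
import torch

try:
    from torch.func import vmap  # PyTorch 2.x location (assumption)
except ImportError:
    from functorch import vmap   # standalone functorch

def f(x, y):
    # Out-of-place add: allocates a new result instead of writing into x.
    return x + y

x = torch.randn(1)
y = torch.randn(3)
x_before = x.clone()

# x is shared across the batch (None), y is mapped over dimension 0
out = vmap(f, in_dims=(None, 0))(x, y)

# Reference computed with plain broadcasting
expected = x + y.unsqueeze(1)
```

The result has one row per element of y, and x itself is left unmodified.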

Data-dependent Python control flow

We don’t yet support vmap over data-dependent control flow. Data-dependent control flow is when the condition of an if-statement, while-loop, or for-loop depends on the values of a Tensor that is being vmap’ed over. For example, the following will raise an error:

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

However, any control flow that is not dependent on the values in vmap’ed tensors will work:

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX supports transforming over data-dependent control flow using special control flow operators (e.g. jax.lax.cond, jax.lax.while_loop). We’re investigating adding equivalents of those to functorch (open an issue on GitHub to voice your support!).
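
In the meantime, a branch on Tensor values can often be expressed as a Tensor operation instead of a Python if-statement. The following sketch rewrites the relu example with torch.where; it is one possible workaround under that assumption, not an approach prescribed by the original text, and the import fallback assumes either PyTorch 2.x or a standalone functorch install.

```python
import torch

try:
    from torch.func import vmap  # PyTorch 2.x location (assumption)
except ImportError:
    from functorch import vmap   # standalone functorch

def relu(x):
    # Select elementwise between x and 0 without a Python-level branch.
    return torch.where(x > 0, x, torch.zeros_like(x))

x = torch.randn(3)
out = vmap(relu)(x)

# Reference result from the built-in operator
expected = torch.relu(x)
```

Because the condition never has to be converted to a Python bool, there is no data-dependent control flow for vmap to reject.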