functorch is a library of JAX-like composable function transforms for PyTorch.


We’ve integrated functorch into PyTorch. As the final step of the integration, the functorch APIs are deprecated as of PyTorch 2.0. Please use the torch.func APIs instead and see the migration guide and docs for more details.

What are composable function transforms?

  • A “function transform” is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.

  • functorch has auto-differentiation transforms (grad(f) returns a function that computes the gradient of f), a vectorization/batching transform (vmap(f) returns a function that computes f over batches of inputs), and others.

  • These function transforms compose with each other arbitrarily. For example, vmap(grad(f)) computes per-sample gradients, a quantity that stock PyTorch cannot compute efficiently today.
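The composition vmap(grad(f)) described above can be sketched as follows. This assumes PyTorch 2.0 or later, with the transforms imported from torch.func as the deprecation note recommends. The linear-model loss, the shapes, and the names loss, w, xs, and ys are hypothetical illustrations. grad differentiates with respect to the first argument by default, and in_dims=(None, 0, 0) tells vmap to share w across the batch while mapping over the samples in xs and ys.

```python
import torch
from torch.func import grad, vmap

def loss(w, x, y):
    # Hypothetical squared-error loss of a linear model on ONE sample:
    # w and x are 1-D tensors of length D, y is a scalar.
    pred = x @ w
    return (pred - y) ** 2

torch.manual_seed(0)
D, N = 4, 8
w = torch.randn(D)      # shared model weights
xs = torch.randn(N, D)  # N samples
ys = torch.randn(N)     # N targets

# grad(loss) returns a function computing d(loss)/dw for one sample.
grad_loss = grad(loss)

# vmap maps grad_loss over the sample dimension of xs and ys (dim 0),
# but not over w (None), giving one gradient per sample.
per_sample_grads = vmap(grad_loss, in_dims=(None, 0, 0))(w, xs, ys)
print(per_sample_grads.shape)  # torch.Size([8, 4])
```

For this loss the per-sample gradient is 2 * (x @ w - y) * x, so the result can be checked by hand against that closed form.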

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample-gradients (or other per-sample quantities)

  • running ensembles of models on a single machine

  • efficiently batching together tasks in the inner-loop of MAML

  • efficiently computing Jacobians and Hessians

  • efficiently computing batched Jacobians and Hessians
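The ensembling use case can be sketched with vmap alone. This assumes PyTorch 2.0 or later and follows the stack_module_state plus functional_call pattern from torch.func. The three small Linear models and the helper call_one are hypothetical. The parameters of all models are stacked along a new leading dimension, and vmap runs one copy of the model per slice of that dimension, all on the same input batch.

```python
import copy
import torch
from torch.func import stack_module_state, functional_call, vmap

torch.manual_seed(0)
# Hypothetical ensemble: three independently initialized models.
models = [torch.nn.Linear(4, 2) for _ in range(3)]

# Stack each parameter (and buffer) across models along a new dim 0.
params, buffers = stack_module_state(models)

# A "stateless" copy on the meta device, used only for its structure.
base = copy.deepcopy(models[0]).to("meta")

def call_one(p, b, x):
    # Run the base module with one model's parameters and buffers.
    return functional_call(base, (p, b), (x,))

x = torch.randn(5, 4)  # one shared input batch
# Map over the stacked params/buffers (dim 0), broadcast x (None).
out = vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # torch.Size([3, 5, 2])
```

The output has one slice per model, so it should match running each model in a Python loop and stacking the results, without the loop.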

Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
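A minimal sketch of the Jacobian and Hessian use cases, assuming PyTorch 2.0 or later. It uses vjp, jacrev, and hessian from torch.func. The functions f and g are hypothetical and were chosen so the results are easy to check by hand. f is elementwise sin, so its Jacobian is diag(cos(x)). g sums x**3, so its Hessian is diag(6 * x). Wrapping the same transforms in vmap gives the batched versions.

```python
import torch
from torch.func import vjp, jacrev, hessian, vmap

def f(x):
    # Hypothetical elementwise map R^3 -> R^3; Jacobian is diag(cos(x)).
    return torch.sin(x)

def g(x):
    # Hypothetical scalar function R^3 -> R; Hessian is diag(6 * x).
    return (x ** 3).sum()

torch.manual_seed(0)
x = torch.randn(3)

# Vector-Jacobian product: v^T J with v = ones, which equals cos(x) here.
out, vjp_fn = vjp(f, x)
(v,) = vjp_fn(torch.ones(3))

J = jacrev(f)(x)   # full Jacobian, shape [3, 3]
H = hessian(g)(x)  # full Hessian, shape [3, 3]

# Batched versions: compose with vmap over a leading batch dimension.
xs = torch.randn(10, 3)
batched_J = vmap(jacrev(f))(xs)   # shape [10, 3, 3]
batched_H = vmap(hessian(g))(xs)  # shape [10, 3, 3]
```

The same jacrev and hessian functions serve both the single-input and batched cases. Batching comes from composing with vmap rather than from a separate batched-Jacobian API.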
