functorch

functorch is a library of JAX-like composable function transforms for PyTorch.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd while delivering good eager-mode performance.
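A minimal sketch of the two core transforms applied to plain tensors. It assumes the `functorch` package is installed; newer PyTorch releases also expose the same transforms under `torch.func`, so the import falls back to that. The functions `f` and `dot` are illustrative only.

```python
import torch

try:
    from functorch import grad, vmap
except ImportError:  # newer PyTorch exposes the same transforms as torch.func
    from torch.func import grad, vmap

torch.manual_seed(0)

def f(x):
    # Illustrative scalar-valued function: sum of sin(x)
    return torch.sin(x).sum()

x = torch.randn(5)
# grad(f) returns a new function that computes df/dx; here df/dx = cos(x)
g = grad(f)(x)

def dot(a, b):
    # Written for a single pair of 1-D vectors, with no batch dimension
    return torch.dot(a, b)

A = torch.randn(3, 5)
B = torch.randn(3, 5)
# vmap maps dot over dim 0 of both inputs, giving one dot product per row
batched = vmap(dot)(A, B)  # shape (3,)
```

Note that `dot` never mentions a batch dimension; `vmap` adds it, which is what lets these transforms compose with code written for a single example.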

Note

This library is currently in beta. This means that the features generally work (unless otherwise documented) and we (the PyTorch team) are committed to bringing this library forward. However, the APIs may change in response to user feedback, and we don’t yet have full coverage of PyTorch operations.

If you have suggestions on the API or use cases you’d like to see covered, please open a GitHub issue or reach out. We’d love to hear about how you’re using the library.

Why composable function transforms?

There are a number of use cases that are tricky to implement in PyTorch today:

  • computing per-sample gradients (or other per-sample quantities)

  • running ensembles of models on a single machine

  • efficiently batching together tasks in the inner-loop of MAML

  • efficiently computing Jacobians and Hessians

  • efficiently computing batched Jacobians and Hessians

Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
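As a sketch of two of these use cases, the example below computes per-sample gradients of a hypothetical squared-error loss for a linear model with `vmap(grad(...))`, then builds a full Jacobian by applying `vmap` to the function returned by `vjp`, batching over the rows of an identity matrix. The loss, weights, and data are made up for illustration, and the import falls back to `torch.func` if `functorch` is not installed.

```python
import torch

try:
    from functorch import grad, vjp, vmap
except ImportError:  # newer PyTorch exposes the same transforms as torch.func
    from torch.func import grad, vjp, vmap

torch.manual_seed(0)

# --- Per-sample gradients: vmap(grad(...)) ---
def loss(w, x, y):
    # Hypothetical squared-error loss for ONE example: x has shape (3,), y is a scalar
    return (torch.dot(w, x) - y) ** 2

w = torch.randn(3)     # shared model weights
X = torch.randn(8, 3)  # batch of 8 inputs
Y = torch.randn(8)     # batch of 8 targets

# grad(loss) differentiates with respect to the first argument (w).
# in_dims=(None, 0, 0): do not batch w; batch X and Y along dim 0.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(w, X, Y)  # shape (8, 3)

# Closed form for comparison: d/dw (w.x - y)^2 = 2 (w.x - y) x
expected = 2 * (X @ w - Y).unsqueeze(1) * X

# --- Jacobian: vmap over the function returned by vjp ---
def elementwise(x):
    # Illustrative R^3 -> R^3 function whose Jacobian is diagonal
    return torch.sin(x) * x

x = torch.randn(3)
_, vjp_fn = vjp(elementwise, x)
# Row i of the Jacobian is the VJP with the i-th one-hot cotangent.
# vmap evaluates all rows in one batched call; vjp_fn returns a tuple
# with one entry per primal input, hence the unpacking.
(jacobian,) = vmap(vjp_fn)(torch.eye(3))
```

The Jacobian construction is roughly the pattern behind functorch's `jacrev`: one reverse-mode pass per output row, batched with `vmap` instead of a Python loop.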