functorch

functorch provides JAX-like composable function transforms for PyTorch.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.

This library is currently under heavy development. If you have suggestions on the API or use cases you’d like to see covered, please open a GitHub issue or reach out. We’d love to hear about how you’re using the library.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample gradients (or other per-sample quantities)

  • running ensembles of models on a single machine

  • efficiently batching together tasks in the inner loop of MAML

  • efficiently computing Jacobians and Hessians

  • efficiently computing batched Jacobians and Hessians

Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
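As a rough illustration of what such composition can look like, the sketch below computes per-sample gradients with vmap(grad(...)) and a full Jacobian with vmap over a vjp function. The linear model, the loss_fn helper, and the tensor shapes are assumptions made up for this example and are not part of the library. The fallback import from torch.func is likewise an assumption about newer PyTorch releases, where the same transforms are exposed under that namespace.

```python
import torch

try:
    from functorch import grad, vjp, vmap
except ImportError:
    # Assumption: newer PyTorch releases expose the same transforms in torch.func.
    from torch.func import grad, vjp, vmap


def loss_fn(weights, x, y):
    """Squared error of a toy linear model on a single example (hypothetical)."""
    pred = x @ weights          # scalar prediction for one example
    return (pred - y) ** 2      # scalar loss, as grad() requires


torch.manual_seed(0)
weights = torch.randn(3)
xs = torch.randn(5, 3)          # batch of 5 examples with 3 features each
ys = torch.randn(5)

# grad(loss_fn) differentiates with respect to the first argument (weights).
# vmap maps that gradient function over dim 0 of xs and ys, while weights
# is shared across the batch (in_dims=None).
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weights, xs, ys)

# Jacobian of an elementwise function: pull back each unit vector through
# the vjp function, batching the pullbacks with vmap.
x = torch.randn(4)
_, vjp_fn = vjp(torch.sin, x)
(jacobian,) = vmap(vjp_fn)(torch.eye(4))
```

Here per_sample_grads has one gradient row per example (shape 5 × 3), and jacobian is the 4 × 4 Jacobian of sin at x. The same result could be obtained with a Python loop over examples, but the composed transform expresses it as a single batched call.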

Read More

For a whirlwind tour of how to use the transforms, please check out this section in our README. For installation instructions or the API reference, see below.