functorch
functorch provides JAX-like composable function transforms for PyTorch.
It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.
This library is currently under heavy development. If you have suggestions on the API or use cases you’d like to be covered, please open a GitHub issue or reach out. We’d love to hear about how you’re using the library.
Why composable function transforms?
There are a number of use cases that are tricky to do in PyTorch today:
computing per-sample gradients (or other per-sample quantities)
running ensembles of models on a single machine
efficiently batching together tasks in the inner-loop of MAML
efficiently computing Jacobians and Hessians
efficiently computing batched Jacobians and Hessians
Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the JAX framework.
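As a rough illustration of two of the use cases listed above, the sketch below composes vmap() with grad() to get per-sample gradients, and vmap() with vjp() to build a full Jacobian. The names loss_fn, elementwise_sin, and the tensor shapes are hypothetical and chosen only for this example. The import assumes a PyTorch version that exposes the transforms under torch.func and falls back to the functorch package otherwise.

```python
import torch

try:
    # Assumption: recent PyTorch releases expose the transforms in torch.func.
    from torch.func import grad, vjp, vmap
except ImportError:
    # Fallback for installations that ship the standalone functorch package.
    from functorch import grad, vjp, vmap


def loss_fn(weights, sample, target):
    # Hypothetical squared-error loss for a single sample (returns a scalar).
    prediction = sample @ weights
    return (prediction - target) ** 2


torch.manual_seed(0)
weights = torch.randn(3)
samples = torch.randn(5, 3)   # batch of 5 samples
targets = torch.randn(5)

# grad(loss_fn) differentiates with respect to the first argument (weights).
# vmap maps that gradient function over the batch dimension of samples and
# targets while sharing the same weights (in_dims=None) across the batch.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(
    weights, samples, targets
)

# Closed-form reference for this particular loss, used only as a sanity check.
expected_grads = 2 * (samples @ weights - targets).unsqueeze(1) * samples


def elementwise_sin(x):
    # Hypothetical vector-valued function whose Jacobian is diag(cos(x)).
    return torch.sin(x)


x = torch.randn(4)

# vjp returns the function output and a function computing vector-Jacobian
# products. Mapping that function over the rows of the identity matrix
# recovers the Jacobian one row at a time.
_, vjp_fn = vjp(elementwise_sin, x)
(jacobian,) = vmap(vjp_fn)(torch.eye(4))
```

Here per_sample_grads has one gradient row per sample, so its shape is (5, 3), and jacobian is a 4 x 4 matrix. Neither computation needed a Python loop over the batch or a dedicated per-sample-gradient or Jacobian subsystem.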
Read More
For a whirlwind tour of how to use the transforms, please check out this section in our README. For installation instructions or the API reference, please check below.