Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", is
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstream, we've decided to migrate from being a top level package
(``functorch``) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the newest APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The following APIs are a drop-in replacement for the following
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_.
They are fully backwards compatible.


==============================  =======================================
functorch API                    PyTorch API (as of PyTorch 2.0)
==============================  =======================================
functorch.vmap                  :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                  :func:`torch.func.grad`
functorch.vjp                   :func:`torch.func.vjp`
functorch.jvp                   :func:`torch.func.jvp`
functorch.jacrev                :func:`torch.func.jacrev`
functorch.jacfwd                :func:`torch.func.jacfwd`
functorch.hessian               :func:`torch.func.hessian`
functorch.functionalize         :func:`torch.func.functionalize`
==============================  =======================================

Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================

NN module utilities
-------------------

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts parameters and inputs to the model as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters and buffers and inputs.

Here's an example of how to compute gradients of parameters of a model using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians of argnums=0 by default.
    # We set it to 1 to compute jacobians of params
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that it is important for memory consumption that you should only carry
around a single copy of your parameters. ``model.named_parameters()`` does not copy
the parameters. If in your model training you update the parameters of the model
in-place, then the ``nn.Module`` that is your model has the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the ``model``. In this case, you should change
``model`` to not hold memory by converting it to the meta device via
``model.to('meta')``.

functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters, and
one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is an example of how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)


functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.