functorch.make_functional
functorch.make_functional(model, disable_autograd_tracking=False) → func, params

Given a torch.nn.Module, make_functional() extracts the state (params) and returns a functional version of the model, func. This makes it possible to use transforms over the parameters of model.

func can be invoked as follows:

    import torch
    import torch.nn as nn
    from functorch import make_functional

    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    func, params = make_functional(model)
    func(params, x)
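As a quick sanity check (a sketch, assuming a PyTorch install where functorch.make_functional is still importable), the functional version should reproduce the module's own output when given the extracted params. Because params is just a tuple of tensors (weight and bias for nn.Linear), func can also be called with any other tuple of the same shapes; zero_params below is a hypothetical name used only for illustration.

```python
import torch
import torch.nn as nn
from functorch import make_functional

torch.manual_seed(0)
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)

# With the extracted params, func reproduces the original module's output.
same = torch.allclose(func(params, x), model(x))

# params is a tuple of tensors (weight, bias for nn.Linear). Passing a different
# tuple of the same shapes evaluates the same architecture with new weights;
# all-zero weights and bias give an all-zero output.
zero_params = tuple(torch.zeros_like(p) for p in params)
zero_out = func(zero_params, x)
print(same, zero_out.abs().max().item())
```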
And here is an example of applying the grad transform over the parameters of a model.
    import torch
    import torch.nn as nn
    from functorch import make_functional, grad

    x = torch.randn(4, 3)
    t = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    func, params = make_functional(model)

    def compute_loss(params, x, t):
        y = func(params, x)
        return nn.functional.mse_loss(y, t)

    grad_weights = grad(compute_loss)(params, x, t)
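The result of grad has the same structure as params, here a (weight, bias) tuple. A minimal check, assuming functorch is importable, compares it against regular PyTorch autograd applied to the original module (the names ref_loss and ref_grads are illustrative only):

```python
import torch
import torch.nn as nn
from functorch import make_functional, grad

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)

def compute_loss(params, x, t):
    y = func(params, x)
    return nn.functional.mse_loss(y, t)

# grad differentiates with respect to the first argument: the params tuple.
grad_weights = grad(compute_loss)(params, x, t)

# Reference: ordinary autograd on the original module's parameters,
# which hold the same values as the extracted params.
ref_loss = nn.functional.mse_loss(model(x), t)
ref_grads = torch.autograd.grad(ref_loss, tuple(model.parameters()))

matches = all(torch.allclose(g, r) for g, r in zip(grad_weights, ref_grads))
print([g.shape for g in grad_weights], matches)
```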
If the model has any buffers, please use make_functional_with_buffers() instead.

Parameters
model (torch.nn.Module) – Input model.
disable_autograd_tracking (bool) – Flag to disable gradient tracking for the output parameters. The returned params are unrelated to the set of params from the original model. If False (default), the params will have requires_grad=True on them (i.e., they will be trackable with regular PyTorch autograd), matching the requires_grad-ness of the params from the original model. Otherwise, the returned params will have requires_grad=False. If you plan on using regular PyTorch autograd (e.g., if you want to call .backward() or torch.autograd.grad()), then set disable_autograd_tracking=False. Otherwise, if you're only planning on using functorch's gradient transforms, then please set disable_autograd_tracking=True to avoid unnecessarily tracking history with PyTorch autograd.
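The two settings can be contrasted in a short sketch, assuming functorch is importable: with the default, .backward() populates .grad on the returned params; with disable_autograd_tracking=True the params have requires_grad=False, and gradients still come from functorch's grad transform. The helper loss is a hypothetical function used only for this illustration.

```python
import torch
import torch.nn as nn
from functorch import make_functional, grad

model = nn.Linear(3, 3)
x = torch.randn(4, 3)

# Default (False): params keep requires_grad=True, so regular autograd works.
func, params = make_functional(model)
print([p.requires_grad for p in params])  # [True, True]
func(params, x).sum().backward()
has_grads = all(p.grad is not None for p in params)

# True: params have requires_grad=False; use functorch transforms for gradients.
func_nt, params_nt = make_functional(model, disable_autograd_tracking=True)
print([p.requires_grad for p in params_nt])  # [False, False]

def loss(params, x):
    # Summing the outputs makes the bias gradient equal to the batch size (4).
    return func_nt(params, x).sum()

g = grad(loss)(params_nt, x)
print(has_grads, [gi.shape for gi in g])
```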
Warning
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.make_functional is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.functional_call instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/master/func.migrating.html
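For comparison, here is a sketch of the grad example rewritten with torch.func.functional_call, assuming PyTorch 2.0 or later. Instead of a params tuple, functional_call takes a dict mapping parameter names to tensors. Detaching the tensors is our assumption for mirroring disable_autograd_tracking=True; it is not required by functional_call itself.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

# Parameters as a {name: tensor} dict; .detach() drops autograd tracking.
params = {k: v.detach() for k, v in model.named_parameters()}

def compute_loss(params, x, t):
    # Run model's forward with the given tensors substituted for its parameters.
    y = functional_call(model, params, (x,))
    return nn.functional.mse_loss(y, t)

# grad returns a dict with the same keys as params.
grad_weights = grad(compute_loss)(params, x, t)
print({k: tuple(g.shape) for k, g in grad_weights.items()})
# {'weight': (3, 3), 'bias': (3,)}
```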