functorch.make_functional_with_buffers
functorch.make_functional_with_buffers(model) → func, params, buffers

Given a torch.nn.Module, make_functional_with_buffers extracts the state (params and buffers) and returns a functional version of the model, func, that can be invoked like a function.

func can be invoked as follows:

```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
# func takes the extracted state explicitly, followed by the model's usual inputs
func(params, buffers, x)
```
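nn.Linear registers no buffers, so in the example above buffers is empty. The sketch below uses nn.BatchNorm1d instead (our choice, not part of the original example), because it has both parameters (weight, bias) and buffers (running_mean, running_var, num_batches_tracked). It puts the module in eval mode, also an assumption made here, so the forward pass reads the running statistics instead of updating them. It then checks that func reproduces the module's output and that func uses whatever state is passed in rather than state stored on the module.

```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

torch.manual_seed(0)
x = torch.randn(4, 3)

# BatchNorm1d has 2 parameters and 3 buffers.
model = nn.BatchNorm1d(3)
model.eval()  # use running statistics; the forward pass does not modify buffers

func, params, buffers = make_functional_with_buffers(model)
print(len(params), len(buffers))  # 2 3

# With the extracted state, func matches the original module.
out = func(params, buffers, x)
print(torch.allclose(out, model(x)))  # True

# func is driven entirely by the state passed in: zeroing weight and bias
# zeroes the output, while the original model is left untouched.
zero_params = tuple(torch.zeros_like(p) for p in params)
zero_out = func(zero_params, buffers, x)
print(bool(torch.all(zero_out == 0)))  # True
```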
And here is an example of applying the grad transform over the parameters of a model:
```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, x, t):
    y = func(params, buffers, x)
    return nn.functional.mse_loss(y, t)

# grad differentiates compute_loss with respect to its first argument, params
grad_weights = grad(compute_loss)(params, buffers, x, t)
```
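Because grad differentiates with respect to the first argument by default, grad_weights has the same structure as params: one gradient tensor per parameter (for nn.Linear, weight then bias). A minimal sketch of a sanity check, assuming only the standard autograd API, compares these gradients with the ones .backward() produces on the original module. The fixed seed is added here for reproducibility.

```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, x, t):
    y = func(params, buffers, x)
    return nn.functional.mse_loss(y, t)

# One gradient per entry of params, in the same order.
grad_weights = grad(compute_loss)(params, buffers, x, t)

# Reference gradients from ordinary autograd on the original module.
nn.functional.mse_loss(model(x), t).backward()
matches = [torch.allclose(g, p.grad) for g, p in zip(grad_weights, model.parameters())]
print(len(grad_weights), matches)  # 2 [True, True]
```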