functorch.make_functional_with_buffers

functorch.make_functional_with_buffers(model) → func, params, buffers

Given a torch.nn.Module, make_functional_with_buffers extracts its state (params and buffers) and returns func, a functional version of the model that takes that state as explicit arguments and can be invoked like a function.

func can be invoked as follows:

import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
func(params, buffers, x)
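
As a minimal sketch of what the returned values look like, the snippet below uses a small model that has buffers as well as parameters and checks that the functional call reproduces the module's own forward pass. The nn.Sequential model, the call to eval(), and the variable names num_params, num_buffers, and outputs_match are assumptions made for illustration; they are not part of the API.

```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

torch.manual_seed(0)
x = torch.randn(4, 3)

# Illustrative model: Linear contributes parameters only, while BatchNorm1d
# contributes both parameters (weight, bias) and buffers (running statistics).
model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
model.eval()  # use the stored running statistics instead of updating them

func, params, buffers = make_functional_with_buffers(model)

# params and buffers are returned as flat collections of tensors.
num_params = len(params)    # Linear weight/bias + BatchNorm weight/bias
num_buffers = len(buffers)  # running_mean, running_var, num_batches_tracked

# The state is passed explicitly; the result should match model(x).
outputs_match = torch.allclose(func(params, buffers, x), model(x))
print(num_params, num_buffers, outputs_match)
```

For a model without buffers, such as the nn.Linear above, buffers comes back empty but is still passed to func.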

And here is an example of applying the grad transform over the parameters of a model:

import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, x, t):
    y = func(params, buffers, x)
    return nn.functional.mse_loss(y, t)

# grad differentiates with respect to the first argument (params) by default
grad_weights = grad(compute_loss)(params, buffers, x, t)
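
The result of grad has the same structure as params, one gradient tensor per parameter. One way this might be used is a hand-written gradient-descent step, sketched below. The learning rate lr and the names new_params, loss_before, and loss_after are hypothetical choices for illustration, not something the library prescribes.

```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, x, t):
    y = func(params, buffers, x)
    return nn.functional.mse_loss(y, t)

# One gradient per parameter, in the same order as params.
grad_weights = grad(compute_loss)(params, buffers, x, t)
shapes_match = all(g.shape == p.shape for g, p in zip(grad_weights, params))

# A single plain gradient-descent step (lr is an assumed, illustrative value).
lr = 0.01
with torch.no_grad():
    new_params = tuple(p - lr * g for p, g in zip(params, grad_weights))

# Because the state is explicit, the updated parameters are passed straight
# back into the same functional model; the original model is left untouched.
loss_before = compute_loss(params, buffers, x, t).item()
loss_after = compute_loss(new_params, buffers, x, t).item()
print(loss_before, loss_after)
```

With a small enough step size, the loss on this batch should decrease after the update.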