
Neural Tangent Kernels

The neural tangent kernel (NTK) is a kernel that describes how a neural network evolves during training, and it has been the subject of much research in recent years. This tutorial, inspired by the implementation of NTKs in JAX, demonstrates how to easily compute this quantity using functorch.

Setup

First, some setup. Let’s define a simple CNN that we wish to compute the NTK of.

import torch
import torch.nn as nn
from functorch import make_functional, vmap, vjp, jvp, jacrev
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 32, (3, 3))
        self.conv3 = nn.Conv2d(32, 32, (3, 3))
        self.fc = nn.Linear(21632, 10)
        
    def forward(self, x):
        x = self.conv1(x)
        x = x.relu()
        x = self.conv2(x)
        x = x.relu()
        x = self.conv3(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x
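The value 21632 for the input size of fc is fixed by the input resolution. Assuming the 32x32 inputs generated below and the Conv2d defaults (stride 1, no padding), each 3x3 convolution trims 2 pixels from each spatial side, so the feature map shrinks from 32 to 26 after three convolutions. The helper conv_out is hypothetical, written only for this arithmetic check:

```python
def conv_out(size: int, kernel: int) -> int:
    # Output side length of a stride-1, unpadded convolution (hypothetical helper).
    return size - kernel + 1

side = 32                      # spatial side of the 32x32 inputs
for _ in range(3):             # conv1, conv2, conv3: all 3x3, no padding
    side = conv_out(side, 3)   # ReLU between them does not change the shape

channels = 32                  # out_channels of conv3
flat_features = channels * side * side
print(side, flat_features)     # 26 21632
```

If you change the input resolution, padding, or kernel sizes, the in_features of fc has to change accordingly, otherwise the forward pass fails with a shape mismatch.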

Next, let’s generate some random data:

x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)

Create a function version of the model

functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.

We’ll use functorch’s make_functional to accomplish the first part: turning the module into a function that takes its parameters as an explicit argument. If your module has buffers, use make_functional_with_buffers instead.

net = CNN().to(device)
fnet, params = make_functional(net)

Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no inter-batch operations. That is, each data point in the batch is independent of other data points. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:

def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)
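The independence assumption is what makes fnet_single valid. As an illustration only (plain NumPy, not functorch), the sketch below compares two toy functions: one that treats each row of the batch independently, and one that subtracts the batch mean, a stand-in for layers such as BatchNorm in training mode. The names f_independent, f_coupled, and single are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 3))   # toy "parameters"
X = rng.standard_normal((5, 3))   # a batch of 5 inputs

def f_independent(W, X):
    # Each output row depends only on the matching input row.
    return X @ W.T

def f_coupled(W, X):
    # Centering by the batch mean couples the samples in the batch.
    Y = X @ W.T
    return Y - Y.mean(axis=0, keepdims=True)

def single(f, W, x):
    # Same add-a-batch-dim / remove-it trick as fnet_single.
    return f(W, x[None])[0]

per_sample_ind = np.stack([single(f_independent, W, x) for x in X])
per_sample_cpl = np.stack([single(f_coupled, W, x) for x in X])

print(np.allclose(per_sample_ind, f_independent(W, X)))  # True
print(np.allclose(per_sample_cpl, f_coupled(W, X)))      # False
```

For the coupled function, a batch of one always centers to zero, so evaluating it one data point at a time no longer reproduces the batched output.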

Compute the NTK: method 1

We’re ready to compute the empirical NTK. The empirical NTK for two data points x1 and x2 is defined as the product of the Jacobian of the model (taken with respect to its parameters) evaluated at x1 and the transposed Jacobian evaluated at x2:

\[J_{net}(x1) \cdot J_{net}^T(x2)\]
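As a worked example of this definition, consider a linear model \(f(W, x) = Wx\) with \(W\) of shape K x D (a toy model assumed here, not the CNN above). Its Jacobian with respect to \(W\), flattened row-major, has entries \(\partial f_a / \partial W_{ij} = \delta_{ai} x_j\), so the empirical NTK works out to \((x_1^\top x_2)\, I_K\). This NumPy sketch builds the Jacobians explicitly and checks that:

```python
import numpy as np

K, D = 3, 4                         # output size and input size (arbitrary)
rng = np.random.default_rng(0)
x1 = rng.standard_normal(D)
x2 = rng.standard_normal(D)

def jacobian_linear(x, K):
    # Jacobian of f(W, x) = W @ x w.r.t. W flattened row-major:
    # d f_a / d W_ij = delta_ai * x_j, so row a is e_a (Kronecker) x.
    return np.kron(np.eye(K), x[None, :])   # shape (K, K * D)

J1 = jacobian_linear(x1, K)
J2 = jacobian_linear(x2, K)
ntk = J1 @ J2.T                             # shape (K, K)

print(np.allclose(ntk, (x1 @ x2) * np.eye(K)))  # True
```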

In the batched case, where x1 and x2 are both batches of data points, we want the product of the Jacobians for every pair of data points drawn from x1 and x2. Here’s how to compute the NTK in the batched case:

def empirical_ntk(fnet_single, params, x1, x2):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]
    
    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]
    
    # Compute J(x1) @ J(x2).T
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result
result = empirical_ntk(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])
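Why does empirical_ntk contract each parameter tensor separately and then call sum(0)? The Jacobian with respect to all parameters is the concatenation of the per-tensor Jacobians along the parameter axis f, and a contraction over a concatenated axis splits into a sum of contractions over its pieces. A NumPy check, with random arrays standing in for the Jacobians of two hypothetical parameter tensors:

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, K = 4, 3, 2          # batch sizes and output size (arbitrary)
sizes = [5, 7]             # parameter counts of two hypothetical parameter tensors

# Per-tensor Jacobians, already flattened like `j.flatten(2)` above.
jac1 = [rng.standard_normal((N, K, p)) for p in sizes]
jac2 = [rng.standard_normal((M, K, p)) for p in sizes]

# Per-tensor contraction followed by a sum, as in empirical_ntk.
per_tensor = sum(np.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2))

# The same quantity with all parameters concatenated into one axis.
J1 = np.concatenate(jac1, axis=2)   # shape (N, K, 12)
J2 = np.concatenate(jac2, axis=2)   # shape (M, K, 12)
full = np.einsum('Naf,Mbf->NMab', J1, J2)

print(full.shape)                    # (4, 3, 2, 2)
print(np.allclose(per_tensor, full)) # True
```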

In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the non-diagonal elements can be approximated by zero. It’s easy to adjust the above function to do that:

def empirical_ntk(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]
    
    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]
    
    # Compute J(x1) @ J(x2).T
    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        raise ValueError(f'Unknown compute mode: {compute!r}')
        
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result
result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])
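The 'trace' and 'diagonal' subscripts reuse the output index a on both Jacobians, so they compute the trace and the diagonal of each K x K block without forming its off-diagonal entries. A NumPy sketch with random stand-in Jacobians (shapes chosen arbitrarily) that compares them against the full result:

```python
import numpy as np

rng = np.random.default_rng(0)
j1 = rng.standard_normal((4, 3, 6))   # (N, K, P)
j2 = rng.standard_normal((5, 3, 6))   # (M, K, P)

full = np.einsum('Naf,Mbf->NMab', j1, j2)   # (N, M, K, K)
trace = np.einsum('Naf,Maf->NM', j1, j2)    # (N, M)
diag = np.einsum('Naf,Maf->NMa', j1, j2)    # (N, M, K)

print(np.allclose(trace, np.trace(full, axis1=2, axis2=3)))    # True
print(np.allclose(diag, np.diagonal(full, axis1=2, axis2=3)))  # True
```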

Compute the NTK: method 2

The next method computes the NTK implicitly. It has different tradeoffs from the previous one and is generally more efficient when your model has a large number of parameters; we recommend trying both methods to see which works better for your model.

Here’s our definition of NTK:

\[J_{net}(x1) \cdot J_{net}^T(x2)\]

The implicit computation reformulates the problem by inserting an identity matrix and regrouping the matrix multiplications:

\[= J_{net}(x1) \cdot J_{net}^T(x2) \cdot I\]
\[= (J_{net}(x1) \cdot (J_{net}^T(x2) \cdot I))\]
  • Let \(vjps = (J_{net}^T(x2) \cdot I)\). We can use a vector-Jacobian product to compute this.

  • Now, consider \(J_{net}(x1) \cdot vjps\). This is a Jacobian-vector product!
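The reformulation can be checked with explicit matrices. In this NumPy sketch, random matrices J1 and J2 stand in for \(J_{net}(x1)\) and \(J_{net}(x2)\), and the hypothetical functions vjp_x2 and jvp_x1 play the roles of reverse-mode and forward-mode AD; with real autodiff neither Jacobian would be materialized. Each column of I costs one VJP and one JVP:

```python
import numpy as np

rng = np.random.default_rng(0)
K, P = 3, 8                        # output size, number of parameters (arbitrary)
J1 = rng.standard_normal((K, P))   # stands in for J_net(x1)
J2 = rng.standard_normal((K, P))   # stands in for J_net(x2)

def vjp_x2(v):
    # Vector-Jacobian product v^T J_net(x2): a vector in parameter space.
    return v @ J2

def jvp_x1(t):
    # Jacobian-vector product J_net(x1) t: a vector in output space.
    return J1 @ t

I = np.eye(K)
# Column k of J1 @ J2.T equals J1 @ (J2.T @ e_k): one VJP, then one JVP.
columns = [jvp_x1(vjp_x2(I[k])) for k in range(K)]
ntk_implicit = np.stack(columns, axis=1)

print(np.allclose(ntk_implicit, J1 @ J2.T))  # True
```

Stacking the same vectors along axis 0 instead would give the transpose, J2 @ J1.T, so when the slices are stacked by vmap it matters which input feeds the VJP and which feeds the JVP.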

This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. Let’s code that up:

def empirical_ntk_implicit(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

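        def get_ntk_slice(vec):
            # `vec` is a unit vector (a single row of the identity matrix).
            # vjp_fn(vec) computes vec @ J(x1), i.e. J(x1).T @ vec, shaped like params.
            vjps = vjp_fn(vec)
            # jvp computes J(x2) @ vjps, which is row `vec` of J(x1) @ J(x2).T,
            # so vmap over the basis stacks the rows of J(x1) @ J(x2).T.
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps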

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)
        
    # get_ntk(x1, x2) computes the NTK for a single pair of data points (x1, x2).
    # Since the x1, x2 inputs to empirical_ntk_implicit are batched,
    # we actually wish to compute the NTK between every pair of data points
    # drawn from {x1} and {x2}. That's what the nested vmaps here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)
    raise ValueError(f'Unknown compute mode: {compute!r}')
result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)
result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)
assert torch.allclose(result_implicit, result_explicit, atol=1e-5)
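The 'trace' and 'diagonal' branches of empirical_ntk_implicit use einsum with a repeated subscript in a single operand: 'NMKK' selects the diagonal of the last two axes. NumPy's einsum accepts the same subscripts, so a quick check with a random array standing in for the NTK shows what they return:

```python
import numpy as np

rng = np.random.default_rng(0)
result = rng.standard_normal((4, 5, 3, 3))   # (N, M, K, K), like the 'full' NTK

trace = np.einsum('NMKK->NM', result)   # sum of each K x K block's diagonal
diag = np.einsum('NMKK->NMK', result)   # diagonal of each K x K block

print(np.allclose(trace, np.trace(result, axis1=2, axis2=3)))    # True
print(np.allclose(diag, np.diagonal(result, axis1=2, axis2=3)))  # True
```

Unlike method 1, the implicit version always builds the full K x K blocks first and reduces them afterwards, so asking for the trace or diagonal here does not save any computation.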

Our code for empirical_ntk_implicit looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.