# Neural Tangent Kernels

The neural tangent kernel (NTK) is a kernel that describes how a neural network evolves during training. There has been a lot of research around it in recent years. This tutorial, inspired by the implementation of NTKs in JAX, demonstrates how to easily compute this quantity using functorch.

## Setup

First, some setup. Let’s define a simple CNN that we wish to compute the NTK of.

```
import torch
import torch.nn as nn
from functorch import make_functional, vmap, vjp, jvp, jacrev

device = 'cuda'

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 32, (3, 3))
        self.conv3 = nn.Conv2d(32, 32, (3, 3))
        self.fc = nn.Linear(21632, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.relu()
        x = self.conv2(x)
        x = x.relu()
        x = self.conv3(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x
```

And let’s generate some random data:

```
x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)
```

## Create a function version of the model

functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.

We’ll use functorch’s `make_functional` to accomplish the first step. If your module has buffers, you’ll want to use `make_functional_with_buffers` instead.

```
net = CNN().to(device)
fnet, params = make_functional(net)
```
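
Our CNN has no buffers, but for reference, here’s a minimal sketch of the buffers variant (assuming the same `net` as above); the functional version it returns takes the buffers as an extra argument:

```
from functorch import make_functional_with_buffers

# Hypothetical sketch: if CNN had buffers (e.g. BatchNorm running statistics),
# we'd extract them alongside the parameters.
fnet_b, params_b, buffers_b = make_functional_with_buffers(net)
# The returned function takes the buffers between the params and the input:
# output = fnet_b(params_b, buffers_b, x)
```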

Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no inter-batch operations. That is, each data point in the batch is independent of other data points. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:

```
def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)
```
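
As a quick sanity check (not part of the pipeline itself), we can confirm that `fnet_single` maps a single input to a single output vector:

```
# One (3, 32, 32) image in, one vector of 10 logits out.
single_out = fnet_single(params, x_train[0])
print(single_out.shape)  # torch.Size([10])
```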

## Compute the NTK: method 1

We’re ready to compute the empirical NTK. The empirical NTK for two data points `x1` and `x2` is defined as the matrix product between the Jacobian of the model evaluated at `x1` and the transpose of the Jacobian of the model evaluated at `x2`:

\[J_{net}(x1) \cdot J_{net}^T(x2)\]

In the batched case where `x1` is a batch of data points and `x2` is a batch of data points, we want the matrix product between the Jacobians of all combinations of data points from `x1` and `x2`. Here’s how to compute the NTK in the batched case:

```
def empirical_ntk(fnet_single, params, x1, x2):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result
```

```
result = empirical_ntk(fnet_single, params, x_train, x_test)
print(result.shape)
```

```
torch.Size([20, 5, 10, 10])
```
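
The shape is `[N, M, O, O]`: the leading two dimensions index the 20 points in `x_train` and the 5 points in `x_test`, while the trailing two are the model’s 10 output dimensions, so entry `[i, j]` is the 10 x 10 NTK matrix between `x_train[i]` and `x_test[j]`.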

In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the non-diagonal elements can be approximated by zero. It’s easy to adjust the above function to do that:

```
def empirical_ntk(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        assert False

    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result
```

```
result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
```

```
torch.Size([20, 5])
```
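
The `'diagonal'` option works the same way, keeping one value per output dimension instead of summing over them:

```
result = empirical_ntk(fnet_single, params, x_train, x_test, 'diagonal')
print(result.shape)  # torch.Size([20, 5, 10])
```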

## Compute the NTK: method 2

The next method we will discuss is a way to compute the NTK implicitly. This has different tradeoffs compared to the previous one: it never materializes the full Jacobians, so it is generally more efficient when the model has a large number of parameters; we recommend trying out both methods to see which works better for your problem.

Here’s our definition of the NTK:

\[J_{net}(x1) \cdot J_{net}^T(x2)\]

The implicit computation reformulates the problem by adding an identity matrix and rearranging the matrix-multiplies:

\[J_{net}(x1) \cdot J_{net}^T(x2) = J_{net}(x1) \cdot (J_{net}^T(x2) \cdot I)\]

Let \(vjps = J_{net}^T(x2) \cdot I\). We can use a vector-Jacobian product to compute this.

Now, consider \(J_{net}(x1) \cdot vjps\). This is a Jacobian-vector product!

This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. Let’s code that up:

```
def empirical_ntk_implicit(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # `vec` is a unit vector, a single row of the identity matrix.
            # This vjp computes vec @ J(x1), i.e. one row of J(x1)
            vjps = vjp_fn(vec)
            # This jvp computes J(x2) @ vjps, i.e. one row of J(x1) @ J(x2).T
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)

    # get_ntk(x1, x2) computes the NTK for a single pair of data points.
    # Since the x1, x2 inputs to empirical_ntk_implicit are batched,
    # we actually wish to compute the NTK between every pair of data points
    # from {x1} and {x2}. That's what the vmaps here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        # `result` has shape [N, M, K, K]; trace over the last two dims
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)
```

```
result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)
result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)
assert torch.allclose(result_implicit, result_explicit, atol=1e-5)
```

Our code for `empirical_ntk_implicit` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.
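
If you want to check which method is faster for your own model and data, here’s a rough timing sketch using `torch.utils.benchmark` (our addition, not part of the original recipe; it assumes the functions and tensors defined above):

```
from torch.utils.benchmark import Timer

# Timer handles warmup and CUDA synchronization for us.
globals_dict = {
    'empirical_ntk': empirical_ntk,
    'empirical_ntk_implicit': empirical_ntk_implicit,
    'fnet_single': fnet_single,
    'params': params,
    'x_train': x_train,
    'x_test': x_test,
}
t_explicit = Timer(stmt='empirical_ntk(fnet_single, params, x_test, x_train)', globals=globals_dict)
t_implicit = Timer(stmt='empirical_ntk_implicit(fnet_single, params, x_test, x_train)', globals=globals_dict)
print(t_explicit.timeit(10))
print(t_implicit.timeit(10))
```

Which method wins depends on your model size and output dimension, so measure on your actual architecture.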