{ "cells": [ { "cell_type": "markdown", "id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3", "metadata": {}, "source": [ "# Neural Tangent Kernels\n", "\n", "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents), demonstrates how to easily compute this quantity using functorch." ] }, { "cell_type": "markdown", "id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64", "metadata": {}, "source": [ "## Setup\n", "\n", "First, some setup. Let's define a simple CNN that we wish to compute the NTK of." ] }, { "cell_type": "code", "execution_count": 1, "id": "855fa70b-5b63-4973-94df-41be57ab6ecf", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from functorch import make_functional, vmap, vjp, jvp, jacrev\n", "device = 'cuda'\n", "\n", "class CNN(nn.Module):\n", " def __init__(self):\n", " super(CNN, self).__init__()\n", " self.conv1 = nn.Conv2d(3, 32, (3, 3))\n", " self.conv2 = nn.Conv2d(32, 32, (3, 3))\n", " self.conv3 = nn.Conv2d(32, 32, (3, 3))\n", " self.fc = nn.Linear(21632, 10)\n", " \n", " def forward(self, x):\n", " x = self.conv1(x)\n", " x = x.relu()\n", " x = self.conv2(x)\n", " x = x.relu()\n", " x = self.conv3(x)\n", " x = x.flatten(1)\n", " x = self.fc(x)\n", " return x" ] }, { "cell_type": "markdown", "id": "52c600e9-207a-41ec-93b4-5d940827bda0", "metadata": {}, "source": [ "And let's generate some random data" ] }, { "cell_type": "code", "execution_count": 2, "id": "0001a907-f5c9-4532-9ee9-2e94b8487d08", "metadata": {}, "outputs": [], "source": [ "x_train = torch.randn(20, 3, 32, 32, device=device)\n", "x_test = torch.randn(5, 3, 32, 32, device=device)" ] }, { "cell_type": "markdown", "id": "8af210fe-9613-48ee-a96c-d0836458b0f1", "metadata": {}, "source": [ "## Create a function version of the model\n", "\n", "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n", "\n", "We'll use functorch's make_functional to accomplish the first step. If your module has buffers, you'll want to use make_functional_with_buffers instead." ] }, { "cell_type": "code", "execution_count": 3, "id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387", "metadata": {}, "outputs": [], "source": [ "net = CNN().to(device)\n", "fnet, params = make_functional(net)" ] }, { "cell_type": "markdown", "id": "319276a4-da45-499a-af47-0677107559b6", "metadata": {}, "source": [ "Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no inter-batch operations. That is, each data point in the batch is independent of other data points. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:" ] }, { "cell_type": "code", "execution_count": 4, "id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc", "metadata": {}, "outputs": [], "source": [ "def fnet_single(params, x):\n", " return fnet(params, x.unsqueeze(0)).squeeze(0)" ] }, { "cell_type": "markdown", "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248", "metadata": {}, "source": [ "## Compute the NTK: method 1\n", "\n", "We're ready to compute the empirical NTK. 
  {
   "cell_type": "markdown",
   "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
   "metadata": {},
   "source": [
    "## Compute the NTK: method 1\n",
    "\n",
    "We're ready to compute the empirical NTK. The empirical NTK for two data points `x1` and `x2` is defined as an inner product between the Jacobian of the model evaluated at `x1` and the Jacobian of the model evaluated at `x2`:\n",
    "\n",
    "$$J_{net}(x1) \\\\cdot J_{net}^T(x2)$$\n",
    "\n",
    "When `x1` and `x2` are batches of data points, we want the inner product between the Jacobians of all pairs of data points from `x1` and `x2`. Here's how to compute the NTK in the batched case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840",
   "metadata": {},
   "outputs": [],
   "source": [
    "def empirical_ntk(fnet_single, params, x1, x2):\n",
    "    # Compute J(x1): one Jacobian per parameter, each of shape [N, output_dim, *param_shape]\n",
    "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
    "    # Flatten each Jacobian to [N, output_dim, num_param_elements]\n",
    "    jac1 = [j.flatten(2) for j in jac1]\n",
    "\n",
    "    # Compute J(x2)\n",
    "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
    "    jac2 = [j.flatten(2) for j in jac2]\n",
    "\n",
    "    # Compute J(x1) @ J(x2).T, summing the contributions from each parameter\n",
    "    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
    "    result = result.sum(0)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 5, 10, 10])\n"
     ]
    }
   ],
   "source": [
    "result = empirical_ntk(fnet_single, params, x_train, x_test)\n",
    "print(result.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea844f45-98fb-4cba-8056-644292b968ab",
   "metadata": {},
   "source": [
    "In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the off-diagonal elements can be approximated by zero. It's easy to adjust the above function to do that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aae760c9-e906-4fda-b490-1126a86b7e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "def empirical_ntk(fnet_single, params, x1, x2, compute='full'):\n",
    "    # Compute J(x1)\n",
    "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
    "    jac1 = [j.flatten(2) for j in jac1]\n",
    "\n",
    "    # Compute J(x2)\n",
    "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
    "    jac2 = [j.flatten(2) for j in jac2]\n",
    "\n",
    "    # Compute J(x1) @ J(x2).T\n",
    "    einsum_expr = None\n",
    "    if compute == 'full':\n",
    "        einsum_expr = 'Naf,Mbf->NMab'  # result shape [N, M, output_dim, output_dim]\n",
    "    elif compute == 'trace':\n",
    "        einsum_expr = 'Naf,Maf->NM'    # result shape [N, M]\n",
    "    elif compute == 'diagonal':\n",
    "        einsum_expr = 'Naf,Maf->NMa'   # result shape [N, M, output_dim]\n",
    "    else:\n",
    "        assert False\n",
    "\n",
    "    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
    "    result = result.sum(0)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 5])\n"
     ]
    }
   ],
   "source": [
    "result = empirical_ntk(fnet_single, params, x_train, x_test, 'trace')\n",
    "print(result.shape)"
   ]
  },
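  {
   "cell_type": "markdown",
   "id": "7e8f0a12-3b4c-4d5e-9f60-a1b2c3d4e5f6",
   "metadata": {},
   "source": [
    "For completeness, here is a small illustrative cell (not part of the original walkthrough) showing the `'diagonal'` option, which keeps one entry per output dimension instead of summing them; the expected shape in the comment assumes the 10-class CNN and the random data defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c0d1e2f-3a4b-4c6d-7e8f-0a1b2c3d4e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'diagonal' returns the diagonal of each [output_dim, output_dim] block,\n",
    "# so the result has shape [20, 5, 10] for our data.\n",
    "result_diag = empirical_ntk(fnet_single, params, x_train, x_test, 'diagonal')\n",
    "print(result_diag.shape)"
   ]
  },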
  {
   "cell_type": "markdown",
   "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
   "metadata": {},
   "source": [
    "## Compute the NTK: method 2\n",
    "\n",
    "The next method we will discuss is a way to compute the NTK implicitly. This has different tradeoffs compared to the previous one, and it is generally more efficient for models with a large number of parameters; we recommend trying out both methods to see which works better for your use case.\n",
    "\n",
    "Here's our definition of the NTK:\n",
    "\n",
    "$$J_{net}(x1) \\\\cdot J_{net}^T(x2)$$\n",
    "\n",
    "The implicit computation reformulates the problem by adding an identity matrix and rearranging the matrix-multiplies:\n",
    "\n",
    "$$= J_{net}(x1) \\\\cdot J_{net}^T(x2) \\\\cdot I$$\n",
    "$$= (J_{net}(x1) \\\\cdot (J_{net}^T(x2) \\\\cdot I))$$\n",
    "\n",
    "- Let $vjps = (J_{net}^T(x2) \\\\cdot I)$. We can use a vector-Jacobian product to compute this.\n",
    "- Now, consider $J_{net}(x1) \\\\cdot vjps$. This is a Jacobian-vector product!\n",
    "\n",
    "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK. In the implementation below, the vector-Jacobian product is taken against `x1` and the Jacobian-vector product against `x2`, which builds up $J_{net}(x1) \\\\cdot J_{net}^T(x2)$ one row at a time. Let's code that up:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dc4b49d7-3096-45d5-a7a1-7032309a2613",
   "metadata": {},
   "outputs": [],
   "source": [
    "def empirical_ntk_implicit(func, params, x1, x2, compute='full'):\n",
    "    def get_ntk(x1, x2):\n",
    "        def func_x1(params):\n",
    "            return func(params, x1)\n",
    "\n",
    "        def func_x2(params):\n",
    "            return func(params, x2)\n",
    "\n",
    "        output, vjp_fn = vjp(func_x1, params)\n",
    "\n",
    "        def get_ntk_slice(vec):\n",
    "            # `vec` is a unit vector (a single row of the identity matrix)\n",
    "            # This computes J(x1).T @ vec\n",
    "            vjps = vjp_fn(vec)\n",
    "            # This computes J(x2) @ J(x1).T @ vec, i.e. one row of the NTK\n",
    "            _, jvps = jvp(func_x2, (params,), vjps)\n",
    "            return jvps\n",
    "\n",
    "        # Here's our identity matrix; stacking the slices over all of its rows\n",
    "        # recovers J(x1) @ J(x2).T\n",
    "        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n",
    "        return vmap(get_ntk_slice)(basis)\n",
    "\n",
    "    # get_ntk(x1, x2) computes the NTK for a single pair of data points x1, x2.\n",
    "    # Since the x1, x2 inputs to empirical_ntk_implicit are batched,\n",
    "    # we actually wish to compute the NTK between every pair of data points\n",
    "    # from {x1} and {x2}. That's what the vmaps here do.\n",
    "    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
    "\n",
    "    if compute == 'full':\n",
    "        return result\n",
    "    if compute == 'trace':\n",
    "        return torch.einsum('NMKK->NM', result)\n",
    "    if compute == 'diagonal':\n",
    "        return torch.einsum('NMKK->NMK', result)\n",
    "    assert False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_implicit = empirical_ntk_implicit(fnet_single, params, x_test, x_train)\n",
    "result_explicit = empirical_ntk(fnet_single, params, x_test, x_train)\n",
    "assert torch.allclose(result_implicit, result_explicit, atol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84253466-971d-4475-999c-fe3de6bd25b5",
   "metadata": {},
   "source": [
    "Our code for `empirical_ntk_implicit` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}