{ "cells": [ { "cell_type": "markdown", "id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3", "metadata": {}, "source": [ "# Neural Tangent Kernels\n", "\n", "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch." ] }, { "cell_type": "markdown", "id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64", "metadata": {}, "source": [ "## Setup\n", "\n", "First, some setup. Let's define a simple CNN whose NTK we want to compute." ] }, { "cell_type": "code", "execution_count": 1, "id": "855fa70b-5b63-4973-94df-41be57ab6ecf", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from functorch import make_functional, vmap, vjp, jvp, jacrev\n", "# Fall back to the CPU when no GPU is available\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "class CNN(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.conv1 = nn.Conv2d(3, 32, (3, 3))\n", "        self.conv2 = nn.Conv2d(32, 32, (3, 3))\n", "        self.conv3 = nn.Conv2d(32, 32, (3, 3))\n", "        self.fc = nn.Linear(21632, 10)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = x.relu()\n", "        x = self.conv2(x)\n", "        x = x.relu()\n", "        x = self.conv3(x)\n", "        x = x.flatten(1)\n", "        x = self.fc(x)\n", "        return x" ] }, { "cell_type": "markdown", "id": "52c600e9-207a-41ec-93b4-5d940827bda0", "metadata": {}, "source": [ "Next, let's generate some random data:" ] }, { "cell_type": "code", "execution_count": 2, "id": "0001a907-f5c9-4532-9ee9-2e94b8487d08", "metadata": {}, "outputs": [], "source": [ "x_train = torch.randn(20, 3, 32, 32, device=device)\n", "x_test = 
torch.randn(5, 3, 32, 32, device=device)" ] }, { "cell_type": "markdown", "id": "8af210fe-9613-48ee-a96c-d0836458b0f1", "metadata": {}, "source": [ "## Create a function version of the model\n", "\n", "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n", "\n", "We'll use functorch's `make_functional` to get a function that takes the parameters as an explicit argument. If your module has buffers, you'll want to use `make_functional_with_buffers` instead." ] }, { "cell_type": "code", "execution_count": 3, "id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387", "metadata": {}, "outputs": [], "source": [ "net = CNN().to(device)\n", "fnet, params = make_functional(net)" ] }, { "cell_type": "markdown", "id": "319276a4-da45-499a-af47-0677107559b6", "metadata": {}, "source": [ "Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, no operation mixes data across the batch dimension; that is, each data point in the batch is independent of the others. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:" ] }, { "cell_type": "code", "execution_count": 4, "id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc", "metadata": {}, "outputs": [], "source": [ "def fnet_single(params, x):\n", "    return fnet(params, x.unsqueeze(0)).squeeze(0)" ] }, { "cell_type": "markdown", "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248", "metadata": {}, "source": [ "## Compute the NTK: method 1 (Jacobian contraction)\n", "\n", "We're ready to compute the empirical NTK. 
The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the transpose of the Jacobian of the model evaluated at $x_2$:\n", "\n", "$$J_{net}(x_1) J_{net}^T(x_2)$$\n", "\n", "In the batched case, where $x_1$ and $x_2$ are each a batch of data points, we want this matrix product for every combination of a data point from $x_1$ with a data point from $x_2$.\n", "\n", "The first method consists of doing just that: computing the two Jacobians and contracting them. Here's how to compute the NTK in the batched case:" ] }, { "cell_type": "code", "execution_count": 5, "id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840", "metadata": {}, "outputs": [], "source": [ "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n", "    # Compute J(x1)\n", "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n", "    jac1 = [j.flatten(2) for j in jac1]\n", "\n", "    # Compute J(x2)\n", "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n", "    jac2 = [j.flatten(2) for j in jac2]\n", "\n", "    # Compute J(x1) @ J(x2).T\n", "    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n", "    result = result.sum(0)\n", "    return result" ] }, { "cell_type": "code", "execution_count": 6, "id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([20, 5, 10, 10])\n" ] } ], "source": [ "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n", "print(result.shape)" ] }, { "cell_type": "markdown", "id": "ea844f45-98fb-4cba-8056-644292b968ab", "metadata": {}, "source": [ "In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the off-diagonal elements can be approximated by zero. 
It's easy to adjust the above function to do that:" ] }, { "cell_type": "code", "execution_count": 7, "id": "aae760c9-e906-4fda-b490-1126a86b7e96", "metadata": {}, "outputs": [], "source": [ "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n", "    # Compute J(x1)\n", "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n", "    jac1 = [j.flatten(2) for j in jac1]\n", "\n", "    # Compute J(x2)\n", "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n", "    jac2 = [j.flatten(2) for j in jac2]\n", "\n", "    # Compute J(x1) @ J(x2).T\n", "    einsum_expr = None\n", "    if compute == 'full':\n", "        einsum_expr = 'Naf,Mbf->NMab'\n", "    elif compute == 'trace':\n", "        einsum_expr = 'Naf,Maf->NM'\n", "    elif compute == 'diagonal':\n", "        einsum_expr = 'Naf,Maf->NMa'\n", "    else:\n", "        raise ValueError(f'unknown compute mode: {compute}')\n", "\n", "    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n", "    result = result.sum(0)\n", "    return result" ] }, { "cell_type": "code", "execution_count": 8, "id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([20, 5])\n" ] } ], "source": [ "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n", "print(result.shape)" ] }, { "cell_type": "markdown", "id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa", "metadata": {}, "source": [ "The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $ + N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details." 
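, "\n", "\n", "As a rough, back-of-the-envelope illustration for the CNN defined above (a sketch, not a benchmark): counting weights and biases gives $896$ parameters for `conv1`, $9248$ each for `conv2` and `conv3`, and $216330$ for `fc`. Assuming the batch sizes used in this notebook ($20$ and $5$, so $20 \\cdot 5$ pairs take the place of $N^2$) and $O = 10$, the contraction term works out to roughly:\n", "\n", "$$P = 896 + 9248 + 9248 + 216330 = 235722, \\qquad 20 \\cdot 5 \\cdot O^2 \\cdot P = 100 \\cdot 100 \\cdot 235722 \\approx 2.4 \\times 10^9$$\n", "\n", "multiply-adds, ignoring constant factors."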
] }, { "cell_type": "markdown", "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa", "metadata": {}, "source": [ "## Compute the NTK: method 2 (NTK-vector products)\n", "\n", "The next method we will discuss is a way to compute the NTK using NTK-vector products.\n", "\n", "This method reformulates the NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n", "\n", "$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n", "where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n", "\n", "- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n", "- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n", "- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n", "\n", "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n", "\n", "Let's code that up:" ] }, { "cell_type": "code", "execution_count": 9, "id": "dc4b49d7-3096-45d5-a7a1-7032309a2613", "metadata": {}, "outputs": [], "source": [ "def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n", "    def get_ntk(x1, x2):\n", "        def func_x1(params):\n", "            return func(params, x1)\n", "\n", "        def func_x2(params):\n", "            return func(params, x2)\n", "\n", "        output, vjp_fn = vjp(func_x1, params)\n", "\n", "        def get_ntk_slice(vec):\n", "            # This computes vec @ J(x1) (a vector-Jacobian product)\n", "            # `vec` is a unit vector (a single row of the identity matrix)\n", "            vjps = vjp_fn(vec)\n", "            # This computes J(x2) @ vjps (a Jacobian-vector product),\n", "            # i.e. the row of J(x1) @ J(x2).T picked out by `vec`\n", "            _, jvps = jvp(func_x2, (params,), vjps)\n", "            return jvps\n", "\n", "        # Here's our identity matrix\n", "        basis = torch.eye(output.numel(), dtype=output.dtype, 
device=output.device).view(output.numel(), -1)\n", "        return vmap(get_ntk_slice)(basis)\n", "\n", "    # get_ntk(x1, x2) computes the NTK for a single pair of data points x1, x2.\n", "    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n", "    # we actually wish to compute the NTK between every pair of data points\n", "    # between {x1} and {x2}. That's what the vmaps here do.\n", "    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n", "\n", "    if compute == 'full':\n", "        return result\n", "    if compute == 'trace':\n", "        return torch.einsum('NMKK->NM', result)\n", "    if compute == 'diagonal':\n", "        return torch.einsum('NMKK->NMK', result)\n", "    raise ValueError(f'unknown compute mode: {compute}')" ] }, { "cell_type": "code", "execution_count": 10, "id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245", "metadata": {}, "outputs": [], "source": [ "result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n", "result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n", "assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)" ] }, { "cell_type": "markdown", "id": "84253466-971d-4475-999c-fe3de6bd25b5", "metadata": {}, "source": [ "Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n", "\n", "The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of the model's parameters). 
Therefore, this method is preferable when $O P$ is large relative to $[FP]$, as is the case for fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }