{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n", "# https://pytorch.org/tutorials/beginner/colab\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Neural Tangent Kernels\n", "======================\n", "\n", "The neural tangent kernel (NTK) is a kernel that describes [how a neural\n", "network evolves during\n", "training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). It has\n", "been the subject of extensive research [in recent\n", "years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the\n", "implementation of [NTKs in\n", "JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width\n", "Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details),\n", "demonstrates how to easily compute this quantity using `torch.func`,\n", "the composable function transforms for PyTorch.\n", "\n", "This tutorial requires PyTorch 2.0.0 or later.\n", "