{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Jacobians, Hessians, hvp, vhp, and more: composing function transforms\n\nComputing jacobians or hessians is useful in a number of non-traditional\ndeep learning models. It is difficult (or annoying) to compute these quantities\nefficiently using PyTorch's regular autodiff APIs\n(Tensor.backward(), torch.autograd.grad). PyTorch's\n[JAX-inspired](https://github.com/google/jax)\n[function transforms API](https://pytorch.org/docs/master/func.html)\nprovides ways of computing various higher-order autodiff quantities\nefficiently.\n\n

#### Note

This tutorial requires PyTorch 2.0.0 or later.

\n\n## Computing the Jacobian\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn.functional as F\nfrom functools import partial\n_ = torch.manual_seed(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with a function that we'd like to compute the jacobian of.\nThis is a simple linear function with non-linear activation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def predict(weight, bias, x):\n    return F.linear(x, weight, bias).tanh()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's add some dummy data: a weight, a bias, and a feature vector x.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "D = 16\nweight = torch.randn(D, D)\nbias = torch.randn(D)\nx = torch.randn(D)  # feature vector" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's think of predict as a function that maps the input x from $R^D \\to R^D$.\nPyTorch Autograd computes vector-Jacobian products. In order to compute the full\nJacobian of this $R^D \\to R^D$ function, we would have to compute it row-by-row\nby using a different unit vector each time.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def compute_jac(xp):\n    # torch.autograd.grad returns a tuple with one gradient per input; keep the only entry\n    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n                     for vec in unit_vectors]\n    return torch.stack(jacobian_rows)\n\nxp = x.clone().requires_grad_()\nunit_vectors = torch.eye(D)\n\njacobian = compute_jac(xp)\n\nprint(jacobian.shape)\nprint(jacobian[0])  # show first row" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of computing the jacobian row-by-row, we can use PyTorch's\ntorch.vmap function transform to get rid of the for-loop and vectorize the\ncomputation. 
We can\u2019t directly apply vmap to torch.autograd.grad;\ninstead, PyTorch provides a torch.func.vjp transform that composes with\ntorch.vmap:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import vmap, vjp\n\n_, vjp_fn = vjp(partial(predict, weight, bias), x)\n\nft_jacobian, = vmap(vjp_fn)(unit_vectors)\n\n# let's confirm both methods compute the same result\nassert torch.allclose(ft_jacobian, jacobian)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In a later tutorial a composition of reverse-mode AD and vmap will give us\nper-sample-gradients.\nIn this tutorial, composing reverse-mode AD and vmap gives us Jacobian\ncomputation!\nVarious compositions of vmap and autodiff transforms can give us different\ninteresting quantities.\n\nPyTorch provides torch.func.jacrev as a convenience function that performs\nthe vmap-vjp composition to compute jacobians. jacrev accepts an argnums\nargument that says which argument we would like to compute Jacobians with\nrespect to.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import jacrev\n\nft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n\n# Confirm by running the following:\nassert torch.allclose(ft_jacobian, jacobian)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's compare the performance of the two ways to compute the jacobian.\nThe function transform version is much faster (and becomes even faster the\nmore outputs there are).\n\nIn general, we expect that vectorization via vmap can help eliminate overhead\nand give better utilization of your hardware.\n\nvmap does this magic by pushing the outer loop down into the function's\nprimitive operations in order to obtain better performance.\n\nLet's make a quick function to evaluate performance and deal with\nmicroseconds and milliseconds 
measurements:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_perf(first, first_descriptor, second, second_descriptor):\n    \"\"\"Takes torch.utils.benchmark measurements and compares delta of second vs first.\"\"\"\n    # Measurement.times is a list of per-run times; use its first entry\n    faster = second.times[0]\n    slower = first.times[0]\n    gain = (slower - faster) / slower\n    if gain < 0:\n        gain *= -1\n    final_gain = gain * 100\n    print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And then run the performance comparison:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.utils.benchmark import Timer\n\nwithout_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\nwith_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n\nno_vmap_timer = without_vmap.timeit(500)\nwith_vmap_timer = with_vmap.timeit(500)\n\nprint(no_vmap_timer)\nprint(with_vmap_timer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's do a relative performance comparison of the above with our get_perf function:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Furthermore, it\u2019s pretty easy to flip the problem around and say we want to\ncompute Jacobians with respect to the parameters of our model (weight, bias)\ninstead of the input:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# note the change in input via argnums parameters of 0,1 to map to weight and bias\nft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reverse-mode 
Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n\nWe offer two APIs to compute jacobians: jacrev and jacfwd:\n\n- jacrev uses reverse-mode AD. As you saw above it is a composition of our\n  vjp and vmap transforms.\n- jacfwd uses forward-mode AD. It is implemented as a composition of our\n  jvp and vmap transforms.\n\njacfwd and jacrev can be substituted for each other but they have different\nperformance characteristics.\n\nAs a general rule of thumb, if you\u2019re computing the jacobian of an $R^N \\to R^M$\nfunction, and there are many more outputs than inputs (that is, $M > N$) then\njacfwd is preferred, otherwise use jacrev. There are exceptions to this rule,\nbut a non-rigorous argument for this follows:\n\nIn reverse-mode AD, we are computing the jacobian row-by-row, while in\nforward-mode AD (which computes Jacobian-vector products), we are computing\nit column-by-column. The Jacobian matrix has M rows and N columns, so if it\nis taller or wider one way we may prefer the method that deals with fewer\nrows or columns.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import jacrev, jacfwd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's benchmark with more outputs than inputs:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "Din = 32\nDout = 2048\nweight = torch.randn(Dout, Din)\n\nbias = torch.randn(Dout)\nx = torch.randn(Din)\n\n# remember the general rule about taller vs wider... 
here we have a taller matrix:\nprint(weight.shape)\n\nusing_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\nusing_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n\njacfwd_timing = using_fwd.timeit(500)\njacrev_timing = using_bwd.timeit(500)\n\nprint(f'jacfwd time: {jacfwd_timing}')\nprint(f'jacrev time: {jacrev_timing}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then do a relative benchmark:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# with more outputs than inputs we expect jacfwd to be the faster (second) entry\nget_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and now with more inputs (N) than outputs (M):\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "Din = 2048\nDout = 32\nweight = torch.randn(Dout, Din)\nbias = torch.randn(Dout)\nx = torch.randn(Din)\n\nusing_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\nusing_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n\njacfwd_timing = using_fwd.timeit(500)\njacrev_timing = using_bwd.timeit(500)\n\nprint(f'jacfwd time: {jacfwd_timing}')\nprint(f'jacrev time: {jacrev_timing}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and a relative performance comparison:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# with more inputs than outputs we expect jacrev to be the faster (second) entry\nget_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hessian computation with torch.func.hessian\nWe offer a convenience API to compute hessians: torch.func.hessian.\nHessians are the jacobian of the jacobian (or the partial derivative of\nthe partial derivative, aka second order).\n\nThis suggests that one can just compose 
torch.func jacobian transforms to\ncompute the Hessian.\nIndeed, under the hood, hessian(f) is simply jacfwd(jacrev(f)).\n\nNote: to boost performance, depending on your model, you may also want to\nuse jacfwd(jacfwd(f)) or jacrev(jacrev(f)) instead to compute hessians,\nleveraging the rule of thumb above regarding wider vs taller matrices.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import hessian\n\n# let's reduce the size in order not to overwhelm Colab. Hessians require\n# significant memory:\nDin = 512\nDout = 32\nweight = torch.randn(Dout, Din)\nbias = torch.randn(Dout)\nx = torch.randn(Din)\n\nhess_api = hessian(predict, argnums=2)(weight, bias, x)\nhess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\nhess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's verify we have the same result regardless of using the hessian API or\nusing jacfwd(jacfwd()).\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.allclose(hess_api, hess_fwdfwd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch Jacobian and Batch Hessian\nIn the above examples we\u2019ve been operating with a single feature vector.\nIn some cases you might want to take the Jacobian of a batch of outputs\nwith respect to a batch of inputs. 
That is, given a batch of inputs of\nshape (B, N) and a function that goes from $R^N \\to R^M$, we would like\na Jacobian of shape (B, M, N).\n\nThe easiest way to do this is to use vmap:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "batch_size = 64\nDin = 31\nDout = 33\n\nweight = torch.randn(Dout, Din)\nprint(f\"weight shape = {weight.shape}\")\n\nbias = torch.randn(Dout)\n\nx = torch.randn(batch_size, Din)\n\ncompute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\nbatch_jacobian0 = compute_batch_jacobian(weight, bias, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you have a function that goes from (B, N) -> (B, M) instead and are\ncertain that each input produces an independent output, then it's also\nsometimes possible to do this without using vmap by summing the outputs\nand then computing the Jacobian of that function:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def predict_with_output_summed(weight, bias, x):\n    return predict(weight, bias, x).sum(0)\n\nbatch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\nassert torch.allclose(batch_jacobian0, batch_jacobian1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you instead have a function that goes from $R^N \\to R^M$ but inputs that\nare batched, you compose vmap with jacrev to compute batched jacobians, as in\nthe first example of this section.\n\nFinally, batch hessians can be computed similarly. 
It's easiest to think\nabout them by using vmap to batch over hessian computation, but in some\ncases the sum trick also works.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n\nbatch_hess = compute_batch_hessian(weight, bias, x)\nbatch_hess.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing Hessian-vector products\nThe naive way to compute a Hessian-vector product (hvp) is to materialize\nthe full Hessian and multiply it by a vector. We can do better:\nit turns out we don't need to materialize the full Hessian to do this. We'll\ngo through two (of many) different strategies to compute Hessian-vector products:\n\n- composing reverse-mode AD with reverse-mode AD\n- composing reverse-mode AD with forward-mode AD\n\nComposing reverse-mode AD with forward-mode AD (as opposed to reverse-mode\nwith reverse-mode) is generally the more memory efficient way to compute an\nhvp because forward-mode AD doesn't need to construct an Autograd graph and\nsave intermediates for backward:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import jvp, grad, vjp\n\ndef hvp(f, primals, tangents):\n    # jvp returns (output, jvp_out); the Hessian-vector product is the second element\n    return jvp(grad(f), primals, tangents)[1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's some sample usage.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def f(x):\n    return x.sin().sum()\n\nx = torch.randn(2048)\ntangent = torch.randn(2048)\n\nresult = hvp(f, (x,), (tangent,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If PyTorch forward-AD does not have coverage for your operations, then we can\ninstead compose reverse-mode AD with reverse-mode AD:\n\n" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def hvp_revrev(f, primals, tangents):\n    _, vjp_fn = vjp(grad(f), *primals)\n    return vjp_fn(*tangents)\n\nresult_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n# vjp_fn returns a tuple with one entry per primal; compare against the first\nassert torch.allclose(result, result_hvp_revrev[0])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 0 }