.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/jacobians_hessians.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_jacobians_hessians.py:


Jacobians, Hessians, hvp, vhp, and more: composing function transforms
======================================================================

Computing Jacobians or Hessians is useful in a number of non-traditional
deep learning models. It is difficult (or annoying) to compute these quantities
efficiently using PyTorch's regular autodiff APIs
(``Tensor.backward()``, ``torch.autograd.grad``). PyTorch's
`JAX-inspired `_
`function transforms API `_
provides ways of computing various higher-order autodiff quantities
efficiently.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.

Computing the Jacobian
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 22-28

.. code-block:: default


    import torch
    import torch.nn.functional as F
    from functools import partial
    _ = torch.manual_seed(0)

.. GENERATED FROM PYTHON SOURCE LINES 29-31

Let's start with a function that we'd like to compute the Jacobian of.
This is a simple linear function with a non-linear activation.

.. GENERATED FROM PYTHON SOURCE LINES 31-35

.. code-block:: default


    def predict(weight, bias, x):
        return F.linear(x, weight, bias).tanh()

.. GENERATED FROM PYTHON SOURCE LINES 36-37

Let's add some dummy data: a weight, a bias, and a feature vector ``x``.

.. GENERATED FROM PYTHON SOURCE LINES 37-43

.. code-block:: default


    D = 16
    weight = torch.randn(D, D)
    bias = torch.randn(D)
    x = torch.randn(D)  # feature vector

.. GENERATED FROM PYTHON SOURCE LINES 44-48

Let's think of ``predict`` as a function that maps the input ``x`` from
:math:`R^D \to R^D`. PyTorch Autograd computes vector-Jacobian products.
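
Concretely, given a cotangent vector :math:`v`, a vector-Jacobian product returns
:math:`v^T J` without materializing :math:`J`. For this particular ``predict``, the
Jacobian with respect to ``x`` also has a closed form,
:math:`J = \operatorname{diag}(1 - y^2)\,W` with :math:`y = \tanh(Wx + b)`, which gives
a cheap sanity check for a single ``torch.autograd.grad`` call. The sketch below is
illustrative only. It rebuilds ``predict`` so that it runs on its own, and the names
``J_closed``, ``v`` and ``e0`` are introduced just for this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def predict(weight, bias, x):
    # Same linear layer + tanh as in the tutorial
    return F.linear(x, weight, bias).tanh()

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D, requires_grad=True)

y = predict(weight, bias, x)

# Closed form for y = tanh(W x + b): J = diag(1 - y**2) @ W,
# i.e. row i of W scaled by (1 - y_i**2)
J_closed = (1 - y.detach() ** 2).unsqueeze(1) * weight

# One reverse-mode call with cotangent v returns v^T J (a single VJP)
v = torch.randn(D)
(vjp_result,) = torch.autograd.grad(y, x, grad_outputs=v)
assert torch.allclose(vjp_result, v @ J_closed, atol=1e-5)

# A unit cotangent e_0 picks out row 0 of J
e0 = torch.zeros(D)
e0[0] = 1.0
(row0,) = torch.autograd.grad(predict(weight, bias, x), x, grad_outputs=e0)
assert torch.allclose(row0, J_closed[0], atol=1e-5)
```

Choosing :math:`v` to be the unit vector :math:`e_i` recovers row :math:`i` of
:math:`J`, which is why building the whole Jacobian this way takes one call per row.
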
In order to compute the full Jacobian of this :math:`R^D \to R^D` function, we would
have to compute it row-by-row, using a different unit vector each time.

.. GENERATED FROM PYTHON SOURCE LINES 48-62

.. code-block:: default


    def compute_jac(xp):
        jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                         for vec in unit_vectors]
        return torch.stack(jacobian_rows)

    xp = x.clone().requires_grad_()
    unit_vectors = torch.eye(D)

    jacobian = compute_jac(xp)

    print(jacobian.shape)
    print(jacobian[0])  # show first row

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([16, 16])
    tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,
             0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])

.. GENERATED FROM PYTHON SOURCE LINES 63-68

Instead of computing the Jacobian row-by-row, we can use PyTorch's
``torch.vmap`` function transform to get rid of the for-loop and vectorize the
computation. We can't directly apply ``vmap`` to ``torch.autograd.grad``;
instead, PyTorch provides a ``torch.func.vjp`` transform that composes with
``torch.vmap``:

.. GENERATED FROM PYTHON SOURCE LINES 68-78

.. code-block:: default


    from torch.func import vmap, vjp

    _, vjp_fn = vjp(partial(predict, weight, bias), x)

    ft_jacobian, = vmap(vjp_fn)(unit_vectors)

    # let's confirm both methods compute the same result
    assert torch.allclose(ft_jacobian, jacobian)

.. GENERATED FROM PYTHON SOURCE LINES 79-90

In a later tutorial, a composition of reverse-mode AD and ``vmap`` will give us
per-sample gradients. In this tutorial, composing reverse-mode AD and ``vmap``
gives us Jacobian computation! Various compositions of ``vmap`` and autodiff
transforms can give us different interesting quantities.

PyTorch provides ``torch.func.jacrev`` as a convenience function that performs
the ``vmap-vjp`` composition to compute Jacobians. ``jacrev`` accepts an
``argnums`` argument that says which argument we would like to compute
Jacobians with respect to.
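
To see what that convenience function saves us, here is a minimal sketch of the same
``vmap``-``vjp`` composition wrapped as a reusable transform. ``naive_jacrev`` is a
hypothetical helper written for this example. It assumes a function of a single tensor
that returns a 1-D tensor, and it is a simplified illustration rather than the library's
implementation, which also handles ``argnums``, multiple inputs and outputs of any shape.

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev, vjp, vmap

torch.manual_seed(0)

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

def naive_jacrev(f):
    # Hypothetical helper: Jacobian of f (one tensor in, 1-D tensor out)
    # via one VJP per output element, batched with vmap.
    def jac_fn(x):
        out, vjp_fn = vjp(f, x)
        # Row i of the identity is the unit cotangent e_i, selecting row i of J
        cotangents = torch.eye(out.numel(), dtype=out.dtype)
        # vjp_fn returns a tuple with one entry per primal input
        (jac,) = vmap(vjp_fn)(cotangents)
        return jac
    return jac_fn

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)

f = lambda inp: predict(weight, bias, inp)
ours = naive_jacrev(f)(x)
reference = jacrev(predict, argnums=2)(weight, bias, x)

assert ours.shape == (D, D)
assert torch.allclose(ours, reference)
```

The real ``jacrev`` lets us skip the ``lambda`` and pick the differentiated argument
with ``argnums`` instead.
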
.. GENERATED FROM PYTHON SOURCE LINES 90-98

.. code-block:: default


    from torch.func import jacrev

    ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

    # Confirm by running the following:
    assert torch.allclose(ft_jacobian, jacobian)

.. GENERATED FROM PYTHON SOURCE LINES 99-111

Let's compare the performance of the two ways to compute the Jacobian.
The function transform version is much faster (and becomes even faster the
more outputs there are).

In general, we expect that vectorization via ``vmap`` can help eliminate overhead
and give better utilization of your hardware.

``vmap`` does this magic by pushing the outer loop down into the function's
primitive operations in order to obtain better performance.

Let's make a quick helper that reports the relative speedup between two
measurements, so that we don't have to compare microsecond and millisecond
readings by eye:

.. GENERATED FROM PYTHON SOURCE LINES 111-121

.. code-block:: default


    def get_perf(first, first_descriptor, second, second_descriptor):
        """Takes torch.utils.benchmark Measurement objects and reports the time
        saved by ``second`` relative to ``first``; ``second`` should be the
        faster of the two."""
        faster = second.times[0]
        slower = first.times[0]
        gain = (slower - faster) / slower
        if gain < 0:
            gain *= -1
        final_gain = gain * 100
        print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")

.. GENERATED FROM PYTHON SOURCE LINES 122-123

And then run the performance comparison:

.. GENERATED FROM PYTHON SOURCE LINES 123-135

.. code-block:: default


    from torch.utils.benchmark import Timer

    without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
    with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

    no_vmap_timer = without_vmap.timeit(500)
    with_vmap_timer = with_vmap.timeit(500)

    print(no_vmap_timer)
    print(with_vmap_timer)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    compute_jac(xp)
      1.44 ms
      1 measurement, 500 runs , 1 thread
    jacrev(predict, argnums=2)(weight, bias, x)
      412.06 us
      1 measurement, 500 runs , 1 thread
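
Each ``timeit(500)`` call yields a single measurement, as the
``1 measurement`` line in its report shows. As an alternative, not used elsewhere in
this tutorial, ``torch.utils.benchmark.Timer`` also offers ``blocked_autorange``. It
keeps running the statement in blocks until a minimum run time has elapsed and records
one per-run time per block, and the ``median`` of those times is less sensitive to
outliers. The sketch below rebuilds the Jacobian setup so that it runs on its own. The
0.2 second budget is an arbitrary choice.

```python
import torch
import torch.nn.functional as F
from torch.func import jacrev
from torch.utils.benchmark import Timer

torch.manual_seed(0)

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)

# Pass an explicit globals dict instead of globals() so the timed
# statement only sees the names it needs.
timer = Timer(
    stmt="jacrev(predict, argnums=2)(weight, bias, x)",
    globals={"jacrev": jacrev, "predict": predict,
             "weight": weight, "bias": bias, "x": x},
)

# Run blocks of the statement until at least min_run_time seconds have
# elapsed, recording one per-run time per block.
measurement = timer.blocked_autorange(min_run_time=0.2)

# measurement.times holds per-run times in seconds
print(f"blocks measured: {len(measurement.times)}")
print(f"median per run: {measurement.median * 1e6:.1f} us")
```
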
.. GENERATED FROM PYTHON SOURCE LINES 136-137

Let's put the two ``timeit`` results into relative terms with our ``get_perf`` function:

.. GENERATED FROM PYTHON SOURCE LINES 137-140

.. code-block:: default


    get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Performance delta: 71.4022 percent improvement with vmap

.. GENERATED FROM PYTHON SOURCE LINES 141-143

Furthermore, it's pretty easy to flip the problem around and say we want to
compute Jacobians with respect to the parameters of our model (``weight``,
``bias``) instead of the input.

.. GENERATED FROM PYTHON SOURCE LINES 143-147

.. code-block:: default


    # note the change in ``argnums``: (0, 1) selects weight and bias
    ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)

.. GENERATED FROM PYTHON SOURCE LINES 148-171

Reverse-mode Jacobian (``jacrev``) vs forward-mode Jacobian (``jacfwd``)
------------------------------------------------------------------------

We offer two APIs to compute Jacobians: ``jacrev`` and ``jacfwd``:

- ``jacrev`` uses reverse-mode AD. As you saw above, it is a composition of our
  ``vjp`` and ``vmap`` transforms.
- ``jacfwd`` uses forward-mode AD. It is implemented as a composition of our
  ``jvp`` and ``vmap`` transforms.

``jacfwd`` and ``jacrev`` can be substituted for each other, but they have different
performance characteristics.

As a general rule of thumb, if you're computing the Jacobian of an :math:`R^N \to R^M`
function and there are many more outputs than inputs (that is, :math:`M > N`), then
``jacfwd`` is preferred; otherwise, use ``jacrev``. There are exceptions to this rule,
but a non-rigorous argument for it follows.

In reverse-mode AD, we are computing the Jacobian row-by-row, while in
forward-mode AD (which computes Jacobian-vector products), we are computing
it column-by-column.
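
The column-by-column view can be made concrete in the same way as the
``vmap``-``vjp`` composition from the previous section. A Jacobian-vector product with
the unit tangent :math:`e_j` returns column :math:`j` of :math:`J`, and ``vmap`` batches
the :math:`N` JVPs. In the minimal sketch below, ``jvp_column`` is a hypothetical
helper and the sizes are illustrative. The result is only checked against ``jacfwd``
and is not presented as its implementation.

```python
import torch
import torch.nn.functional as F
from torch.func import jacfwd, jvp, vmap

torch.manual_seed(0)

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

Din, Dout = 8, 5  # N = 8 inputs, M = 5 outputs
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

f = lambda inp: predict(weight, bias, inp)

def jvp_column(tangent):
    # jvp returns (f(x), J @ tangent); for a unit tangent e_j that is column j of J
    return jvp(f, (x,), (tangent,))[1]

# Din unit tangents -> Din JVPs batched by vmap. The columns come back stacked
# as rows of a (Din, Dout) tensor, so transpose to get the (Dout, Din) Jacobian.
columns = vmap(jvp_column)(torch.eye(Din))
jac_fwd_by_hand = columns.T

assert jac_fwd_by_hand.shape == (Dout, Din)
assert torch.allclose(jac_fwd_by_hand, jacfwd(predict, argnums=2)(weight, bias, x))
```
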
The Jacobian matrix has :math:`M` rows and :math:`N` columns, so if it is much taller
than it is wide (or vice versa), we may prefer the method that has to compute fewer
of them: ``jacfwd`` (one JVP per column) for tall Jacobians and ``jacrev`` (one VJP
per row) for wide ones.

.. GENERATED FROM PYTHON SOURCE LINES 171-174

.. code-block:: default


    from torch.func import jacrev, jacfwd

.. GENERATED FROM PYTHON SOURCE LINES 175-176

First, let's benchmark with more outputs than inputs:

.. GENERATED FROM PYTHON SOURCE LINES 176-196

.. code-block:: default


    Din = 32
    Dout = 2048
    weight = torch.randn(Dout, Din)

    bias = torch.randn(Dout)
    x = torch.randn(Din)

    # remember the rule about taller vs wider: the Jacobian has the same shape
    # as weight, so here it is a taller matrix (M > N):
    print(weight.shape)

    using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
    using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

    jacfwd_timing = using_fwd.timeit(500)
    jacrev_timing = using_bwd.timeit(500)

    print(f'jacfwd time: {jacfwd_timing}')
    print(f'jacrev time: {jacrev_timing}')

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([2048, 32])
    jacfwd time:
    jacfwd(predict, argnums=2)(weight, bias, x)
      768.65 us
      1 measurement, 500 runs , 1 thread
    jacrev time:
    jacrev(predict, argnums=2)(weight, bias, x)
      8.64 ms
      1 measurement, 500 runs , 1 thread

.. GENERATED FROM PYTHON SOURCE LINES 197-198

and then do a relative benchmark (here ``jacfwd`` is the faster one, so it goes
second):

.. GENERATED FROM PYTHON SOURCE LINES 198-201

.. code-block:: default


    get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Performance delta: 91.1013 percent improvement with jacfwd

.. GENERATED FROM PYTHON SOURCE LINES 202-203

And now the reverse, with more inputs (N) than outputs (M):

.. GENERATED FROM PYTHON SOURCE LINES 203-219
.. code-block:: default


    Din = 2048
    Dout = 32
    weight = torch.randn(Dout, Din)
    bias = torch.randn(Dout)
    x = torch.randn(Din)

    using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
    using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

    jacfwd_timing = using_fwd.timeit(500)
    jacrev_timing = using_bwd.timeit(500)

    print(f'jacfwd time: {jacfwd_timing}')
    print(f'jacrev time: {jacrev_timing}')

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    jacfwd time:
    jacfwd(predict, argnums=2)(weight, bias, x)
      7.02 ms
      1 measurement, 500 runs , 1 thread
    jacrev time:
    jacrev(predict, argnums=2)(weight, bias, x)
      495.24 us
      1 measurement, 500 runs , 1 thread

.. GENERATED FROM PYTHON SOURCE LINES 220-221

and a relative performance comparison (now ``jacrev`` is the faster one):

.. GENERATED FROM PYTHON SOURCE LINES 221-224

.. code-block:: default


    get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Performance delta: 92.9412 percent improvement with jacrev

.. GENERATED FROM PYTHON SOURCE LINES 225-238

Hessian computation with ``torch.func.hessian``
-----------------------------------------------

We offer a convenience API to compute Hessians: ``torch.func.hessian``.
Hessians are the Jacobian of the Jacobian (or the partial derivative of the
partial derivative, aka second order).

This suggests that one can just compose Jacobian transforms to compute the
Hessian. Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.

Note: to boost performance, depending on your model you may also want to use
``jacfwd(jacfwd(f))`` or ``jacrev(jacrev(f))`` instead to compute Hessians,
following the rule of thumb above regarding wider vs taller matrices.

.. GENERATED FROM PYTHON SOURCE LINES 238-253

.. code-block:: default


    from torch.func import hessian

    # Let's reduce the size in order not to overwhelm Colab.
    # Hessians require significant memory:
    Din = 512
    Dout = 32
    weight = torch.randn(Dout, Din)
    bias = torch.randn(Dout)
    x = torch.randn(Din)

    hess_api = hessian(predict, argnums=2)(weight, bias, x)
    hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
    hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)

.. GENERATED FROM PYTHON SOURCE LINES 254-256

Let's verify we get the same result whether we use the ``hessian`` API or
``jacfwd(jacfwd())``.

.. GENERATED FROM PYTHON SOURCE LINES 256-259

.. code-block:: default


    torch.allclose(hess_api, hess_fwdfwd)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    True

.. GENERATED FROM PYTHON SOURCE LINES 260-269

Batch Jacobian and Batch Hessian
--------------------------------

In the above examples we've been operating with a single feature vector.
In some cases you might want to take the Jacobian of a batch of outputs
with respect to a batch of inputs. That is, given a batch of inputs of
shape ``(B, N)`` and a function that goes from :math:`R^N \to R^M`, we would like
a Jacobian of shape ``(B, M, N)``.

The easiest way to do this is to use ``vmap``:

.. GENERATED FROM PYTHON SOURCE LINES 269-284

.. code-block:: default


    batch_size = 64
    Din = 31
    Dout = 33

    weight = torch.randn(Dout, Din)
    print(f"weight shape = {weight.shape}")
    bias = torch.randn(Dout)

    x = torch.randn(batch_size, Din)

    compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
    batch_jacobian0 = compute_batch_jacobian(weight, bias, x)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    weight shape = torch.Size([33, 31])

.. GENERATED FROM PYTHON SOURCE LINES 285-289

If you have a function that goes from :math:`(B, N) \to (B, M)` instead and are
certain that each output row depends only on the corresponding input row, then
it's also sometimes possible to do this without using ``vmap``. Sum the outputs
over the batch dimension and compute the Jacobian of that function. Because the
cross-sample derivatives are zero, the resulting ``(M, B, N)`` Jacobian contains
exactly the per-sample Jacobians, and moving the batch dimension to the front
gives ``(B, M, N)``:

.. GENERATED FROM PYTHON SOURCE LINES 289-296
.. code-block:: default


    def predict_with_output_summed(weight, bias, x):
        return predict(weight, bias, x).sum(0)

    batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
    assert torch.allclose(batch_jacobian0, batch_jacobian1)

.. GENERATED FROM PYTHON SOURCE LINES 297-303

If you instead have a function that goes from :math:`R^N \to R^M` but the inputs
are batched, you compose ``vmap`` with ``jacrev`` to compute batched Jacobians,
as in the first batch example above.

Finally, batch Hessians can be computed similarly. It's easiest to think about
them by using ``vmap`` to batch over Hessian computation, but in some cases the
sum trick also works.

.. GENERATED FROM PYTHON SOURCE LINES 303-309

.. code-block:: default


    compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

    batch_hess = compute_batch_hessian(weight, bias, x)
    batch_hess.shape

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([64, 33, 31, 31])

.. GENERATED FROM PYTHON SOURCE LINES 310-323

Computing Hessian-vector products
---------------------------------

The naive way to compute a Hessian-vector product (hvp) is to materialize the
full Hessian and perform a dot-product with a vector. We can do better: it
turns out we don't need to materialize the full Hessian to do this. We'll go
through two (of many) different strategies to compute Hessian-vector products:

- composing reverse-mode AD with reverse-mode AD
- composing reverse-mode AD with forward-mode AD

Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode
with reverse-mode) is generally the more memory-efficient way to compute an
hvp, because forward-mode AD doesn't need to construct an Autograd graph and
save intermediates for backward:

.. GENERATED FROM PYTHON SOURCE LINES 323-329

.. code-block:: default


    from torch.func import jvp, grad, vjp

    def hvp(f, primals, tangents):
        return jvp(grad(f), primals, tangents)[1]

.. GENERATED FROM PYTHON SOURCE LINES 330-331

Here's some sample usage.
.. GENERATED FROM PYTHON SOURCE LINES 331-340

.. code-block:: default


    def f(x):
        return x.sin().sum()

    x = torch.randn(2048)
    tangent = torch.randn(2048)

    result = hvp(f, (x,), (tangent,))

.. GENERATED FROM PYTHON SOURCE LINES 341-343

If PyTorch forward-AD does not have coverage for your operations, then we can
instead compose reverse-mode AD with reverse-mode AD:

.. GENERATED FROM PYTHON SOURCE LINES 343-350

.. code-block:: default


    def hvp_revrev(f, primals, tangents):
        _, vjp_fn = vjp(grad(f), *primals)
        return vjp_fn(*tangents)

    result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
    assert torch.allclose(result, result_hvp_revrev[0])


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 10.504 seconds)


.. _sphx_glr_download_intermediate_jacobians_hessians.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: jacobians_hessians.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: jacobians_hessians.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_