{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# For tips on running notebooks in Google Colab, see\n",
"# https://pytorch.org/tutorials/beginner/colab\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Forward-mode Automatic Differentiation (Beta)\n",
"=============================================\n",
"\n",
"This tutorial demonstrates how to use forward-mode AD to compute\n",
"directional derivatives (or equivalently, Jacobian-vector products).\n",
"\n",
"The tutorial below uses some APIs only available in versions \\>= 1.11\n",
"(or nightly builds).\n",
"\n",
"Also note that forward-mode AD is currently in beta. The API is subject\n",
"to change and operator coverage is still incomplete.\n",
"\n",
"Basic Usage\n",
"-----------\n",
"\n",
"Unlike reverse-mode AD, forward-mode AD computes gradients eagerly\n",
"alongside the forward pass. We can use forward-mode AD to compute a\n",
"directional derivative by performing the forward pass as before, except\n",
"we first associate our input with another tensor representing the\n",
"direction of the directional derivative (or equivalently, the `v` in a\n",
"Jacobian-vector product). When an input, which we call \\\"primal\\\", is\n",
"associated with a \\\"direction\\\" tensor, which we call \\\"tangent\\\", the\n",
"resultant new tensor object is called a \\\"dual tensor\\\" for its\n",
"connection to dual numbers\\[0\\].\n",
"\n",
"As the forward pass is performed, if any input tensors are dual tensors,\n",
"extra computation is performed to propagate this \\\"sensitivity\\\" of the\n",
"function.\n"
]
},
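{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a brief illustration of the dual-number connection (a sketch using\n",
"notation that is not part of the original text: $\\varepsilon$ is a formal\n",
"symbol with $\\varepsilon^2 = 0$, $x$ is the primal, $v$ is the tangent,\n",
"and $J_f$ is the Jacobian of $f$), evaluating a differentiable function\n",
"$f$ at $x + \\varepsilon v$ gives\n",
"\n",
"$$f(x + \\varepsilon v) = f(x) + \\varepsilon J_f(x) v,$$\n",
"\n",
"so the coefficient of $\\varepsilon$ is the Jacobian-vector product. For\n",
"the elementwise function $f(x, y) = x^2 + y^2$ used in the code below,\n",
"with tangent $v$ attached to $x$ and an implicit zero tangent on $y$,\n",
"\n",
"$$(x + \\varepsilon v)^2 + y^2 = (x^2 + y^2) + \\varepsilon (2 x v),$$\n",
"\n",
"which suggests that the tangent of the output should equal\n",
"`2 * primal * tangent` elementwise.\n"
]
},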
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n",
"import torch.autograd.forward_ad as fwAD\n",
"\n",
"primal = torch.randn(10, 10)\n",
"tangent = torch.randn(10, 10)\n",
"\n",
"def fn(x, y):\n",
"    return x ** 2 + y ** 2\n",
"\n",
"# All forward AD computation must be performed inside a ``dual_level``\n",
"# context. All dual tensors created in such a context will have their\n",
"# tangents destroyed upon exit. This is to ensure that if the output or\n",
"# intermediate results of this computation are reused in a future forward\n",
"# AD computation, their tangents (which are associated with this\n",
"# computation) won't be confused with tangents from the later computation.\n",
"with fwAD.dual_level():\n",
"    # To create a dual tensor, we associate a tensor, which we call the\n",
"    # primal, with another tensor of the same size, which we call the tangent.\n",
"    # If the layout of the tangent is different from that of the primal,\n",
"    # the values of the tangent are copied into a new tensor with the same\n",
"    # metadata as the primal. Otherwise, the tangent itself is used as-is.\n",
"    #\n",
"    # It is also important to note that the dual tensor created by\n",
"    # ``make_dual`` is a view of the primal.\n",
"    dual_input = fwAD.make_dual(primal, tangent)\n",
"    assert fwAD.unpack_dual(dual_input).tangent is tangent\n",
"\n",
"    # To demonstrate the case where the copy of the tangent happens,\n",
"    # we pass in a tangent with a layout different from that of the primal.\n",
"    dual_input_alt = fwAD.make_dual(primal, tangent.T)\n",
"    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent\n",
"\n",
"    # Tensors that do not have an associated tangent are automatically\n",
"    # considered to have a zero-filled tangent of the same shape.\n",
"    plain_tensor = torch.randn(10, 10)\n",
"    dual_output = fn(dual_input, plain_tensor)\n",
"\n",
"    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``\n",
"    # as attributes.\n",
"    jvp = fwAD.unpack_dual(dual_output).tangent\n",
"\n",
"assert fwAD.unpack_dual(dual_output).tangent is None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Usage with Modules\n",
"------------------\n",
"\n",
"To use `nn.Module` with forward AD, replace the parameters of your model\n",
"with dual tensors before performing the forward pass. At the time of\n",
"writing, it is not possible to create an `nn.Parameter` that is a dual\n",
"tensor. As a workaround, one must register the dual tensor as a\n",
"non-parameter attribute of the module.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"model = nn.Linear(5, 5)\n",
"input = torch.randn(16, 5)\n",
"\n",
"params = {name: p for name, p in model.named_parameters()}\n",
"tangents = {name: torch.rand_like(p) for name, p in params.items()}\n",
"\n",
"with fwAD.dual_level():\n",
"    # Replace each registered parameter with a dual tensor stored as a\n",
"    # plain (non-parameter) attribute of the module.\n",
"    for name, p in params.items():\n",
"        delattr(model, name)\n",
"        setattr(model, name, fwAD.make_dual(p, tangents[name]))\n",
"\n",
"    out = model(input)\n",
"    jvp = fwAD.unpack_dual(out).tangent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the functional Module API (beta)\n",
"--------------------------------------\n",
"\n",
"Another way to use `nn.Module` with forward AD is to utilize the\n",
"functional Module API (also known as the stateless Module API).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from torch.func import functional_call\n",
"\n",
"# We need a fresh module because the functional call requires\n",
"# the model to have parameters registered.\n",
"model = nn.Linear(5, 5)\n",
"\n",
"dual_params = {}\n",
"with fwAD.dual_level():\n",
"    for name, p in params.items():\n",
"        # Using the same ``tangents`` from the above section\n",
"        dual_params[name] = fwAD.make_dual(p, tangents[name])\n",
"    out = functional_call(model, dual_params, input)\n",
"    jvp2 = fwAD.unpack_dual(out).tangent\n",
"\n",
"# Check our results\n",
"assert torch.allclose(jvp, jvp2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Custom autograd Function\n",
"------------------------\n",
"\n",
"Custom Functions also support forward-mode AD. To create a custom Function\n",
"supporting forward-mode AD, implement the `jvp()` static method. It is\n",
"possible, but not mandatory, for custom Functions to support both forward\n",
"and backward AD. See the\n",
"[documentation](https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad)\n",
"for more information.\n"
]
},
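{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a short worked derivation for the example that follows (the symbols\n",
"$x$, $v$, and $J_f$ are notation introduced here for illustration and do\n",
"not appear in the code): for the elementwise function $f(x) = \\exp(x)$,\n",
"the Jacobian is diagonal with entries $\\exp(x_i)$, so for a tangent $v$\n",
"\n",
"$$J_f(x) v = \\exp(x) \\odot v.$$\n",
"\n",
"This is why the `jvp()` in the code below can multiply the incoming\n",
"tangent by the forward result saved on `ctx`.\n"
]
},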
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class Fn(torch.autograd.Function):\n",
"    @staticmethod\n",
"    def forward(ctx, foo):\n",
"        result = torch.exp(foo)\n",
"        # Tensors stored in ``ctx`` can be used in the subsequent forward grad\n",
"        # computation.\n",
"        ctx.result = result\n",
"        return result\n",
"\n",
"    @staticmethod\n",
"    def jvp(ctx, gI):\n",
"        gO = gI * ctx.result\n",
"        # If the tensor stored in ``ctx`` will not also be used in the backward pass,\n",
"        # one can manually free it using ``del``.\n",
"        del ctx.result\n",
"        return gO\n",
"\n",
"fn = Fn.apply\n",
"\n",
"primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)\n",
"tangent = torch.randn(10, 10)\n",
"\n",
"with fwAD.dual_level():\n",
"    dual_input = fwAD.make_dual(primal, tangent)\n",
"    dual_output = fn(dual_input)\n",
"    jvp = fwAD.unpack_dual(dual_output).tangent\n",
"\n",
"# It is important to use ``autograd.gradcheck`` to verify that your\n",
"# custom autograd Function computes the gradients correctly. By default,\n",
"# ``gradcheck`` only checks the backward-mode (reverse-mode) AD gradients. Specify\n",
"# ``check_forward_ad=True`` to also check forward grads. If you did not\n",
"# implement the backward formula for your function, you can also tell ``gradcheck``\n",
"# to skip the tests that require backward-mode AD by specifying\n",
"# ``check_backward_ad=False``, ``check_undefined_grad=False``, and\n",
"# ``check_batched_grad=False``.\n",
"torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,\n",
"                         check_backward_ad=False, check_undefined_grad=False,\n",
"                         check_batched_grad=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Functional API (beta)\n",
"---------------------\n",
"\n",
"We also offer a higher-level functional API in functorch for computing\n",
"Jacobian-vector products that you may find simpler to use depending on\n",
"your use case.\n",
"\n",
"The benefit of the functional API is that there isn\\'t a need to\n",
"understand or use the lower-level dual tensor API and that you can\n",
"compose it with other [functorch transforms (like\n",
"vmap)](https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html);\n",
"the downside is that it offers you less control.\n",
"\n",
"Note that the remainder of this tutorial requires functorch to run.\n",
"Please refer to the functorch documentation for installation\n",
"instructions.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import functorch as ft\n",
"\n",
"primal0 = torch.randn(10, 10)\n",
"tangent0 = torch.randn(10, 10)\n",
"primal1 = torch.randn(10, 10)\n",
"tangent1 = torch.randn(10, 10)\n",
"\n",
"def fn(x, y):\n",
"    return x ** 2 + y ** 2\n",
"\n",
"# Here is a basic example to compute the JVP of the above function.\n",
"# ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the\n",
"# computed Jacobian-vector product (JVP). Each primal must be associated\n",
"# with a tangent of the same shape.\n",
"primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))\n",
"\n",
"# ``functorch.jvp`` requires every primal to be associated with a tangent.\n",
"# If we only want to associate certain inputs to ``fn`` with tangents,\n",
"# then we'll need to create a new function that captures inputs without tangents:\n",
"primal = torch.randn(10, 10)\n",
"tangent = torch.randn(10, 10)\n",
"y = torch.randn(10, 10)\n",
"\n",
"import functools\n",
"new_fn = functools.partial(fn, y=y)\n",
"primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the functional API with Modules\n",
"-------------------------------------\n",
"\n",
"To use `nn.Module` with `functorch.jvp` to compute Jacobian-vector\n",
"products with respect to the model parameters, we need to reformulate\n",
"the `nn.Module` as a function that accepts both the model parameters and\n",
"inputs to the module.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model = nn.Linear(5, 5)\n",
"input = torch.randn(16, 5)\n",
"tangents = tuple([torch.rand_like(p) for p in model.parameters()])\n",
"\n",
"# Given a ``torch.nn.Module``, ``ft.make_functional_with_buffers`` extracts the state\n",
"# (``params`` and buffers) and returns a functional version of the model that\n",
"# can be invoked like a function.\n",
"# That is, the returned ``func`` can be invoked like\n",
"# ``func(params, buffers, input)``.\n",
"# ``ft.make_functional_with_buffers`` is analogous to the ``nn.Module`` stateless API\n",
"# that you saw previously and we're working on consolidating the two.\n",
"func, params, buffers = ft.make_functional_with_buffers(model)\n",
"\n",
"# Because ``jvp`` requires every input to be associated with a tangent, we need to\n",
"# create a new function that, when given the parameters, produces the output.\n",
"def func_params_only(params):\n",
"    return func(params, buffers, input)\n",
"\n",
"model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\\[0\\] \n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 0
}