torch.func Whirlwind Tour
=========================

What is torch.func?
-------------------

.. currentmodule:: torch.func

torch.func, previously known as functorch, is a library for
`JAX `_-like composable function transforms in PyTorch.

- A "function transform" is a higher-order function that accepts a numerical
  function and returns a new function that computes a different quantity.
- torch.func has auto-differentiation transforms (``grad(f)`` returns a function
  that computes the gradient of ``f``), a vectorization/batching transform
  (``vmap(f)`` returns a function that computes ``f`` over batches of inputs),
  and others.
- These function transforms can compose with each other arbitrarily. For
  example, composing ``vmap(grad(f))`` computes a quantity called
  per-sample-gradients that stock PyTorch cannot efficiently compute today.

Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, :func:`vjp`, and :func:`jvp` transforms
allows us to express the above without designing a separate subsystem for each.

What are the transforms?
------------------------

:func:`grad` (gradient computation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``grad(func)`` is our gradient computation transform. It returns a new function
that computes the gradients of ``func``. It assumes ``func`` returns a
single-element Tensor and, by default, computes the gradients of the output of
``func`` with respect to the first input.
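
To make the "first input by default" rule concrete, here is a small sketch
using a two-argument function. It relies on the ``argnums`` parameter of
:func:`grad`, which selects the input (or tuple of inputs) to differentiate
with respect to; the function ``f`` is purely illustrative. Since
``f(x, y) = sum(x * y)``, the gradient with respect to ``x`` is ``y`` and vice
versa.

.. code-block:: python

    import torch
    from torch.func import grad

    def f(x, y):
        # Illustrative scalar-valued function of two inputs
        return (x * y).sum()

    x = torch.randn(3)
    y = torch.randn(3)

    # Default (argnums=0): gradient with respect to the first input, d/dx = y
    gx = grad(f)(x, y)

    # argnums=1 differentiates with respect to the second input, d/dy = x
    gy = grad(f, argnums=1)(x, y)

    # A tuple of argnums returns one gradient per selected input
    gx2, gy2 = grad(f, argnums=(0, 1))(x, y)

    assert torch.allclose(gx, y)
    assert torch.allclose(gy, x)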

The derivative of ``torch.sin`` is ``torch.cos``, and nesting ``grad`` gives
higher-order derivatives:

.. code-block:: python

    import torch
    from torch.func import grad

    x = torch.randn([])
    cos_x = grad(lambda x: torch.sin(x))(x)
    assert torch.allclose(cos_x, x.cos())

    # Second-order gradients
    neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
    assert torch.allclose(neg_sin_x, -x.sin())

:func:`vmap` (auto-vectorization)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Note: :func:`vmap` imposes restrictions on the code that it can be used on. For
more details, please see :ref:`ux-limitations`.

``vmap(func)(*inputs)`` is a transform that adds a dimension to all Tensor
operations in ``func``. ``vmap(func)`` returns a new function that maps ``func``
over some dimension (default: 0) of each Tensor in ``inputs``.

``vmap`` is useful for hiding batch dimensions: one can write a function
``func`` that runs on a single example and then lift it, with ``vmap(func)``,
to a function that accepts batches of examples, leading to a simpler modeling
experience:

.. code-block:: python

    import torch
    from torch.func import vmap

    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size, requires_grad=True)

    def model(feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    examples = torch.randn(batch_size, feature_size)
    result = vmap(model)(examples)  # shape: (batch_size,)

When composed with :func:`grad`, :func:`vmap` can be used to compute
per-sample-gradients:

.. code-block:: python

    import torch
    from torch.func import grad, vmap

    batch_size, feature_size = 3, 5

    def model(weights, feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()  # MSELoss

    weights = torch.randn(feature_size, requires_grad=True)
    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)
    inputs = (weights, examples, targets)

    # Share the weights (in_dims=None); map over examples and targets (in_dims=0)
    grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
    assert grad_weight_per_example.shape == (batch_size, feature_size)

:func:`vjp` (vector-Jacobian product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`vjp` transform applies ``func`` to ``inputs`` and returns the outputs
together with a new function that computes the vector-Jacobian product (vjp)
given some ``cotangents`` Tensors.

.. code-block:: python

    import torch
    from torch.func import vjp

    inputs = torch.randn(3)
    func = torch.sin
    cotangents = (torch.randn(3),)

    outputs, vjp_fn = vjp(func, inputs)
    vjps = vjp_fn(*cotangents)  # tuple with one entry per input

    # The Jacobian of sin is diag(cos(x)), so v^T J = v * cos(x)
    assert torch.allclose(vjps[0], cotangents[0] * inputs.cos())

:func:`jvp` (Jacobian-vector product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jvp` transform computes Jacobian-vector products and is also known
as "forward-mode AD". Unlike most other transforms, it is not a higher-order
function: it directly returns the outputs of ``func(inputs)`` as well as the
jvps.

.. code-block:: python

    import torch
    from torch.func import jvp

    x = torch.randn(5)
    y = torch.randn(5)
    f = lambda x, y: (x * y)

    # d(x * y) along tangents (1, 1) is y * 1 + x * 1 = x + y
    _, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
    assert torch.allclose(out_tangent, x + y)

:func:`jacrev`, :func:`jacfwd`, and :func:`hessian`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jacrev` transform returns a new function that takes in ``x`` and
returns the Jacobian of the function with respect to ``x`` using reverse-mode
AD.
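
Reverse-mode AD builds a Jacobian one row at a time: row ``i`` is the
vector-Jacobian product with the one-hot cotangent ``e_i``. The sketch below
illustrates that idea using only :func:`vjp` and :func:`vmap` from the
previous sections and checks the result against :func:`jacrev`. It is meant as
an explanation of the technique, not as a description of how :func:`jacrev` is
implemented internally; the function ``f`` is an illustrative choice that is
not elementwise, so its Jacobian is a full matrix.

.. code-block:: python

    import torch
    from torch.func import jacrev, vjp, vmap

    def f(x):
        # Illustrative non-elementwise function: every output depends on every input
        return x.sin() * x.sum()

    x = torch.randn(4)

    # One reverse-mode pass; vjp_fn maps a cotangent v to v^T J
    _, vjp_fn = vjp(f, x)

    # Rows of the identity are the one-hot cotangents e_i, and e_i^T J is row i
    # of J; vmap evaluates all of them at once and stacks the rows.
    (jac_manual,) = vmap(vjp_fn)(torch.eye(4))

    assert torch.allclose(jac_manual, jacrev(f)(x))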

For an elementwise function such as ``torch.sin``, the Jacobian is diagonal:

.. code-block:: python

    import torch
    from torch.func import jacrev

    x = torch.randn(5)
    jacobian = jacrev(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

:func:`jacrev` can be composed with :func:`vmap` to produce batched Jacobians:

.. code-block:: python

    import torch
    from torch.func import jacrev, vmap

    x = torch.randn(64, 5)
    jacobian = vmap(jacrev(torch.sin))(x)
    assert jacobian.shape == (64, 5, 5)

:func:`jacfwd` is a drop-in replacement for :func:`jacrev` that computes
Jacobians using forward-mode AD:

.. code-block:: python

    import torch
    from torch.func import jacfwd

    x = torch.randn(5)
    jacobian = jacfwd(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

Composing :func:`jacrev` with itself or with :func:`jacfwd` produces Hessians:

.. code-block:: python

    import torch
    from torch.func import jacfwd, jacrev

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hessian0 = jacrev(jacrev(f))(x)  # reverse-over-reverse
    hessian1 = jacfwd(jacrev(f))(x)  # forward-over-reverse
    assert torch.allclose(hessian0, hessian1)

:func:`hessian` is a convenience function that combines :func:`jacfwd` and
:func:`jacrev`:

.. code-block:: python

    import torch
    from torch.func import hessian

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hess = hessian(f)(x)

    # The Hessian of sum(sin(x)) is diag(-sin(x))
    assert torch.allclose(hess, torch.diag(-x.sin()))
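
When only a Hessian-vector product is needed, the full Hessian does not have to
be materialized. One common pattern, sketched below, is forward-over-reverse:
apply :func:`jvp` to :func:`grad`. The helper name ``hvp`` is hypothetical and
used only for illustration; it is not part of ``torch.func``. The result is
checked against the explicit :func:`hessian` and against the closed form
``-sin(x) * v`` that holds for this particular ``f``.

.. code-block:: python

    import torch
    from torch.func import grad, hessian, jvp

    def f(x):
        return x.sin().sum()

    def hvp(f, x, v):
        # Hypothetical helper (not a torch.func API): pushing the tangent v
        # through grad(f) with forward-mode AD yields H(x) @ v directly.
        return jvp(grad(f), (x,), (v,))[1]

    x = torch.randn(5)
    v = torch.randn(5)
    result = hvp(f, x, v)

    # Reference: materialize the Hessian, then multiply by v
    expected = hessian(f)(x) @ v
    assert torch.allclose(result, expected)

This trades one extra forward-mode pass for the memory of a full ``n x n``
Hessian, which matters as the input dimension grows.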