# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Union, Tuple, List
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from .pytree_hacks import tree_map_, treespec_pprint
import torch.autograd.forward_ad as fwAD
from .vmap import vmap
from functorch._C import (
_wrap_for_grad,
_unwrap_for_grad,
_grad_increment_nesting,
_grad_decrement_nesting,
)
# Type used to annotate ``argnums`` parameters (e.g. in ``jacfwd``): either a
# single positional-argument index or a tuple of such indices.
argnums_t = Union[int, Tuple[int, ...]]
def _create_differentiable(inps, level=None):
def create_differentiable(x):
if isinstance(x, torch.Tensor):
return x.requires_grad_()
raise ValueError(f'Thing passed to transform API must be Tensor, '
f'got {type(x)}')
return tree_map(create_differentiable, inps)
def _undo_create_differentiable(inps, level=None):
def unwrap_tensors(x):
if isinstance(x, torch.Tensor):
return _unwrap_for_grad(x, level)
# TODO: Remove the following hack for namedtuples
if isinstance(x, tuple):
return tree_map(unwrap_tensors, tuple(x))
raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
return tree_map(unwrap_tensors, inps)
def _is_differentiable(maybe_tensor):
if not isinstance(maybe_tensor, torch.Tensor):
return False
return maybe_tensor.requires_grad
def _any_differentiable(tensor_or_tuple_of_tensors):
flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors)
return any(tuple(map(_is_differentiable, flat_args)))
def _wrap_tensor_for_grad(maybe_tensor, level):
if not isinstance(maybe_tensor, torch.Tensor):
return maybe_tensor
return _wrap_for_grad(maybe_tensor, level)
def _wrap_all_tensors(tensor_pytree, level):
return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
def _as_tuple(val):
if isinstance(val, tuple):
return val
return (val,)
# Version of autograd.grad that handles outputs that don't depend on inputs
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
if len(result) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*result)
if len(diff_outputs) == 0:
return tuple(torch.zeros_like(inp) for inp in inputs)
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True)
grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
for gi, inp in zip(grad_inputs, inputs))
return grad_inputs
# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
# with torch.no_grad():
# c = x ** 2
# return x - c
#
# The thing to consider is whether grad mode is enabled or disabled before `grad` gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
# with torch.no_grad():
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad, but not the
# outer one. This is because `grad` is a "function transform": its result
# should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
# active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
# then we temporarily restore the previous `is_grad_enabled`. This is
# because we're crossing the boundary from a `grad` outside the
# no_grad to a `grad` inside the no_grad.
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# TODO: forward-mode AD: does it also respect no_grad? What does that mean
# for our jvp transform?
# How do we increment and decrement the nesting? I don't think we can.
def vjp(func: Callable, *primals, has_aux: bool = False):
    """
    Standing for the vector-Jacobian product, returns a tuple containing the
    results of :attr:`func` applied to :attr:`primals` and a function that, when
    given ``cotangents``, computes the reverse-mode Jacobian of :attr:`func` with
    respect to :attr:`primals` times ``cotangents``.

    Args:
        func (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        primals (Tensors): Positional arguments to :attr:`func` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`func`
        applied to :attr:`primals` and a function that computes the vjp of
        :attr:`func` with respect to all :attr:`primals` using the cotangents passed
        to the returned function. If ``has_aux is True``, then instead returns a
        ``(output, vjp_fn, aux)`` tuple.

    Raises:
        ValueError: if any of :attr:`primals` is not a Tensor.
        RuntimeError: if ``has_aux`` is True but :attr:`func` does not return a
            2-tuple, if :attr:`func` returns no output or a non-Tensor output, or
            if any output is not a floating-point or complex Tensor.

    The returned ``vjp_fn`` function will return a tuple of each VJP. Its
    ``cotangents`` argument must have the same pytree structure as the output of
    :attr:`func` (otherwise a RuntimeError is raised). ``vjp_fn`` also accepts
    ``retain_graph`` (default ``True``) and ``create_graph`` (default: whether
    grad mode is enabled at the time ``vjp_fn`` is called).

    When used in simple cases, :func:`vjp` behaves the same as :func:`grad`

    >>> x = torch.randn([5])
    >>> f = lambda x: x.sin().sum()
    >>> (_, vjpfunc) = functorch.vjp(f, x)
    >>> grad = vjpfunc(torch.tensor(1.))[0]
    >>> assert torch.allclose(grad, functorch.grad(f)(x))

    However, :func:`vjp` can support functions with multiple outputs by
    passing in the cotangents for each of the outputs

    >>> x = torch.randn([5])
    >>> f = lambda x: (x.sin(), x.cos())
    >>> (_, vjpfunc) = functorch.vjp(f, x)
    >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
    >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    :func:`vjp` can even support outputs being Python structs

    >>> x = torch.randn([5])
    >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
    >>> (_, vjpfunc) = functorch.vjp(f, x)
    >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
    >>> vjps = vjpfunc(cotangents)
    >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    The function returned by :func:`vjp` will compute the partials with
    respect to each of the :attr:`primals`

    >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
    >>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y)
    >>> cotangents = torch.randn([5, 5])
    >>> vjps = vjpfunc(cotangents)
    >>> assert len(vjps) == 2
    >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
    >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

    :attr:`primals` are the positional arguments for :attr:`f`. All kwargs use their
    default value

    >>> x = torch.randn([5])
    >>> def f(x, scale=4.):
    >>>   return x * scale
    >>>
    >>> (_, vjpfunc) = functorch.vjp(f, x)
    >>> vjps = vjpfunc(torch.ones_like(x))
    >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``vjp``.

        Case 1: Using ``torch.no_grad`` inside a function:

        >>> def f(x):
        >>>     with torch.no_grad():
        >>>         c = x ** 2
        >>>     return x - c

        In this case, ``vjp(f, x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:

        >>> with torch.no_grad():
        >>>     vjp(f, x)

        In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``vjp`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    level = _grad_increment_nesting()
    try:
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            primals = _wrap_all_tensors(primals, level)
            diff_primals = _create_differentiable(primals, level)
            primals_out = func(*diff_primals)

            if has_aux:
                if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
                    raise RuntimeError(
                        "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                primals_out, aux = primals_out
                aux = _undo_create_differentiable(aux, level)

            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
            assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
            results = _undo_create_differentiable(primals_out, level)

            # Only floating-point / complex outputs have meaningful cotangents.
            for primal_out in flat_primals_out:
                assert isinstance(primal_out, torch.Tensor)
                if primal_out.is_floating_point() or primal_out.is_complex():
                    continue
                raise RuntimeError("vjp(f, ...): All outputs of f must be "
                                   "floating-point or complex Tensors, got Tensor "
                                   f"with dtype {primal_out.dtype}")

        def wrapper(cotangents, retain_graph=True, create_graph=None):
            # Default create_graph to the grad mode active when vjp_fn is called,
            # so a vjp_fn invoked under an outer grad transform stays differentiable.
            if create_graph is None:
                create_graph = torch.is_grad_enabled()
            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
            if primals_out_spec != cotangents_spec:
                raise RuntimeError(
                    f'Expected pytree structure of cotangents to be the same '
                    f'as pytree structure of outputs to the function. '
                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
                    f'primal output: {treespec_pprint(primals_out_spec)}')
            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                    retain_graph=retain_graph, create_graph=create_graph)
            return tree_unflatten(result, primals_spec)
    finally:
        _grad_decrement_nesting()

    if has_aux:
        return results, wrapper, aux
    else:
        return results, wrapper
def _safe_zero_index(x):
assert len(x) == 1
return x[0]
def jacrev(func: Callable, argnums: Union[int, Tuple[int, ...]] = 0, *, has_aux=False):
    """
    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
    :attr:`argnums` using reverse mode autodiff

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a function that takes in the same inputs as :attr:`func` and
        returns the Jacobian of :attr:`func` with respect to the arg(s) at
        :attr:`argnums`. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by :attr:`func`.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

    >>> from functorch import jacrev
    >>> x = torch.randn(5)
    >>> jacobian = jacrev(torch.sin)(x)
    >>> expected = torch.diag(torch.cos(x))
    >>> assert torch.allclose(jacobian, expected)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

    >>> from functorch import jacrev
    >>> x = torch.randn(5)
    >>>
    >>> def f(x):
    >>>   return x.sin()
    >>>
    >>> def g(x):
    >>>   result = f(x)
    >>>   return result, result
    >>>
    >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
    >>> assert torch.allclose(f_x, f(x))

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

    >>> from functorch import jacrev, vmap
    >>> x = torch.randn(64, 5)
    >>> jacobian = vmap(jacrev(torch.sin))(x)
    >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians

    >>> from functorch import jacrev
    >>> def f(x):
    >>>   return x.sin().sum()
    >>>
    >>> x = torch.randn(5)
    >>> hessian = jacrev(jacrev(f))(x)
    >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using :attr:`argnums`:

    >>> from functorch import jacrev
    >>> def f(x, y):
    >>>   return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacrev(f, argnums=1)(x, y)
    >>> expected = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
    with respect to multiple arguments

    >>> from functorch import jacrev
    >>> def f(x, y):
    >>>   return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
    >>> expectedX = torch.diag(torch.ones_like(x))
    >>> expectedY = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian[0], expectedX)
    >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.

        Case 1: Using ``torch.no_grad`` inside a function:

        >>> def f(x):
        >>>     with torch.no_grad():
        >>>         c = x ** 2
        >>>     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

        >>> with torch.no_grad():
        >>>     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    @wraps(func)
    def wrapper_fn(*args):
        f_wrapper, primals = _argnums_partial(func, args, argnums)
        vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
        else:
            output, vjp_fn = vjp_out

        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
        flat_output, output_spec = tree_flatten(output)

        # NB: vjp already checks that all outputs are tensors
        # Step 1: Construct grad_outputs by splitting the standard basis
        flat_output_numels = tuple(out.numel() for out in flat_output)
        flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels)
        basis = tree_unflatten(flat_basis, output_spec)

        results = vmap(vjp_fn)(basis)

        flat_primals, primals_spec = tree_flatten(primals)
        flat_results, _ = tree_flatten(results)

        # Step 2: The returned jacobian is one big tensor per input. In this step,
        # we split each Tensor by output.
        flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results]
        flat_input_flat_output = [
            tuple(split.view(out.shape + primal.shape)
                  for split, out in zip(splits, flat_output))
            for splits, primal in zip(flat_results, flat_primals)
        ]

        # Step 3: Right now, `jacobian` is a List[List[Tensor]].
        # The outer List corresponds to the number of primals,
        # the inner List corresponds to the number of outputs.
        # We need to:
        # a. Exchange the order of the outer List and inner List
        # b. tree_unflatten the inner Lists (which correspond to the primals)
        # c. handle the argnums=int case
        # d. tree_unflatten the outer List (which corresponds to the outputs)
        flat_output_flat_input = tuple(zip(*flat_input_flat_output))
        flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
                                  for flat_input in flat_output_flat_input)

        if isinstance(argnums, int):
            # With a single argnum, primals is a 1-tuple; strip that wrapper.
            flat_output_input = tuple(_safe_zero_index(flat_input)
                                      for flat_input in flat_output_input)
        output_input = tree_unflatten(flat_output_input, output_spec)
        if has_aux:
            return output_input, aux
        return output_input
    return wrapper_fn
# NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
#
# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
# It turns out we can compute the jacobian of this function with a single
# call to autograd.grad by using vmap over the correct grad_outputs.
#
# Firstly, one way to compute the jacobian is to concatenate x**2 and x.sum()
# into a 4D vector. E.g., use g(x) = torch.cat([x**2, x.sum().reshape(1)])
#
# To get the first row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
# To get the 2nd row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
# and so on.
#
# Using vmap, we can vectorize all 4 of these computations into one by
# passing the standard basis for R^4 as the grad_output.
# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
#
# Now, how do we compute the jacobian *without stacking the output*?
# We can just split the standard basis across the outputs. So to
# compute the jacobian of f(x), we'd use
# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
# The grad_outputs looks like the following:
# ( torch.tensor([[1, 0, 0],
# [0, 1, 0],
# [0, 0, 1],
# [0, 0, 0]]),
# torch.tensor([[0],
# [0],
# [0],
# [1]]) )
#
# But we're not done yet!
# >>> vmap(partial(autograd.grad, f(x), x))(grad_outputs)
# returns a Tensor of shape [4, 3]. We have to remember to split the
# jacobian of shape [4, 3] into two:
# - one of shape [3, 3] for the first output
# - one of shape [ 3] for the second output
def _construct_standard_basis_for(tensors, tensor_numels):
# This function:
# - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
# - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
# - Each chunk corresponds to one tensor. The chunk has the same dtype and
# device as the tensor
#
# For example, with tensor_numels = [1, 2, 1], this function returns:
# ( tensor([[1], tensor([[0, 0], tensor([[0],
# [0], [1, 0], [0],
# [0], [0, 1], [0],
# [0]]) , [0, 0]]) , [1]]) )
#
# Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
# Precondition: tensors always has at least one element.
#
# See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
# for context behind this function.
assert len(tensors) == len(tensor_numels)
assert len(tensors) > 0
total_numel = sum(tensor_numels)
diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
for tensor, tensor_numel in zip(tensors, tensor_numels))
for chunk, diag_start_idx in zip(chunks, diag_start_indices):
chunk.diagonal(diag_start_idx).fill_(1)
chunks = tuple(chunk.view(total_numel, *tensor.shape)
for chunk, tensor in zip(chunks, tensors))
return chunks
def _validate_and_wrap_argnum(argnum, num_args):
if not isinstance(argnum, int):
raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
if argnum >= 0 and argnum < num_args:
return argnum
if argnum < 0 and argnum >= -num_args:
return argnum + num_args
raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs')
def _check_unique_non_empty(argnums):
if isinstance(argnums, tuple):
if len(argnums) == 0:
raise RuntimeError("argnums must be non-empty")
if len(set(argnums)) != len(argnums):
raise RuntimeError(f"argnums elements must be unique, got {argnums}")
def _replace_args(old_args, new_args, argnums):
if isinstance(argnums, int):
if len(new_args) != 1:
raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
if isinstance(argnums, tuple):
if len(new_args) != len(argnums):
raise RuntimeError(
"new_args should have the same size as argnums. "
f"Argnums size {len(argnums)}, new_args size {len(new_args)}")
def get_right_elem(i):
return new_args[argnums.index(i)] if i in argnums else old_args[i]
return tuple(get_right_elem(i) for i in range(len(old_args)))
raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
def _validate_and_wrap_argnums(argnums, num_args):
if isinstance(argnums, int):
return _validate_and_wrap_argnum(argnums, num_args)
if isinstance(argnums, tuple):
return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
raise AssertionError("Should never get here")
def _slice_argnums(args, argnums, as_tuple=True):
if not isinstance(argnums, int) and not isinstance(argnums, tuple):
raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
argnums = _validate_and_wrap_argnums(argnums, len(args))
_check_unique_non_empty(argnums)
if isinstance(argnums, int):
if as_tuple:
return (args[argnums],)
else:
return args[argnums]
return tuple(args[i] for i in argnums)
def _argnums_partial(f, args, argnums):
    """Curry ``f`` so that only the arguments at ``argnums`` remain free.

    Returns:
        ``(f_wrapper, wrapper_args)`` where ``wrapper_args`` is the tuple of
        arguments selected by ``argnums`` and ``f_wrapper(*new_args)`` calls
        ``f`` with those positions of ``args`` replaced by ``new_args``.
    """
    selected_args = _slice_argnums(args, argnums)
    if not isinstance(selected_args, tuple):
        selected_args = (selected_args, )

    def f_wrapper(*wrapper_args):
        return f(*_replace_args(args, wrapper_args, argnums))

    return (f_wrapper, selected_args)
# Number of jvp calls currently in progress. jvp increments it on entry and
# decrements it on exit; only the outermost call (value 1) opens an
# fwAD.dual_level, nested calls reuse it.
JVP_NESTING = 0
@contextlib.contextmanager
def noop():
    """Context manager that does nothing and yields ``None``.

    ``jvp`` uses it in place of ``fwAD.dual_level`` when it is nested.
    """
    yield None
def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
    """Raise RuntimeError unless ``elts`` is a non-empty tuple of Tensors.

    Args:
        elts: value to check.
        api: API name used as the error-message prefix.
        argname: argument name used in the error message.
    """
    if not isinstance(elts, tuple):
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}')
    # An empty tuple has no elements that could fail the element check below,
    # so testing emptiness first yields the same error as testing it last.
    if not elts:
        raise RuntimeError(
            f'{api}: Expected {argname} to be a non-empty tuple of Tensors.')
    for element in elts:
        if not isinstance(element, torch.Tensor):
            raise RuntimeError(
                f'{api}: Expected {argname} to be a tuple of Tensors, got '
                f'a tuple with an element of type {type(element)}')
def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
    """Raise RuntimeError unless ``output`` is a non-empty sequence of Tensors.

    Args:
        output: flattened outputs of the user function.
        api: API name used as the error-message prefix.

    Raises:
        RuntimeError: if ``output`` is empty or is the single value ``None``
            (f returned nothing), or if any element is not a Tensor.
    """
    # Identity check instead of ``output == [None]``: the list comparison calls
    # Tensor.__eq__ against None, and a ``(None,)`` tuple never equals ``[None]``.
    if len(output) == 0 or (len(output) == 1 and output[0] is None):
        raise RuntimeError(
            f'{api}: Expected f to be a function that has non-empty output (got output = {output})'
        )
    for o in output:
        if not isinstance(o, torch.Tensor):
            raise RuntimeError(
                f'{api}: expected f(*primals) to return only tensors'
                f', got unsupported type {type(o)}'
            )
def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
    """Raise RuntimeError unless ``output`` is a Tensor or a non-empty tuple of Tensors.

    Args:
        output: the value returned by the user function.
        api: API name used as the error-message prefix.
    """
    if isinstance(output, torch.Tensor):
        return
    if isinstance(output, tuple):
        if len(output) == 0:
            raise RuntimeError(
                f'{api}: Expected output of f to be a non-empty tuple of Tensors.')
        for item in output:
            if not isinstance(item, torch.Tensor):
                raise RuntimeError(
                    f'{api}: Expected output of f to be a Tensor or Tensors, got '
                    f'{type(item)} as an output')
        return
    raise RuntimeError(
        f'{api}: Expected output of f to be a Tensor or Tensors, got '
        f'{type(output)}')
def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None:
    """Raise RuntimeError unless ``output`` is a non-empty sequence of Tensors.

    Args:
        output: flattened values to check.
        api: API name used as the error-message prefix.
        argname: argument name used in the error message.
    """
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to contain at least one Tensor.')
    offending_types = [type(item) for item in output
                       if not isinstance(item, torch.Tensor)]
    if offending_types:
        # Report the first non-Tensor encountered.
        raise RuntimeError(
            f'{api}: Expected {argname} to only contain Tensors, got '
            f'{offending_types[0]}')
# Prefix used in jvp / safe_unpack_dual error messages.
jvp_str = 'jvp(f, primals, tangents)'
def safe_unpack_dual(dual, strict):
    """Split a forward-AD dual Tensor into ``(primal, tangent)``.

    If the Tensor carries no tangent (the output does not depend on the dual
    inputs), the tangent becomes ``torch.zeros_like(primal)``, unless
    ``strict`` is True, in which case a RuntimeError is raised.

    Raises:
        RuntimeError: if ``dual`` is not a Tensor, or if ``strict`` is True and
            no tangent is attached.
    """
    if not isinstance(dual, torch.Tensor):
        raise RuntimeError(
            f'{jvp_str}: expected f(*args) to return only tensors'
            f', got unsupported type {type(dual)}'
        )
    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is not None:
        return primal, tangent
    if strict:
        raise RuntimeError(
            'jvp(f, primals, tangents, strict=True): '
            'The output of f is independent of '
            'the inputs. This is not allowed with strict=True.')
    return primal, torch.zeros_like(primal)
def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
    """
    Standing for the Jacobian-vector product, returns a tuple containing
    the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.

    Args:
        func (function): A Python function that takes one or more arguments
            and returns one or more Tensors.
        primals (Tensors): A tuple of positional arguments to :attr:`func`; every
            leaf must be a Tensor. The returned function will also be computing
            the derivative with respect to these arguments
        tangents (Tensors): The "vector" for which Jacobian-vector-product is
            computed. Must be the same structure and sizes as the inputs to
            ``func``.
        strict (bool): If True, raise an error when an output of :attr:`func`
            does not depend on the inputs (i.e. carries no tangent). If False,
            the tangent for such an output is zeros. Default: False.
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
        evaluated at ``primals`` and the Jacobian-vector product.
        If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.

    Raises:
        RuntimeError: if ``primals`` is not a tuple, if ``primals`` and
            ``tangents`` differ in pytree structure, if either contains no
            Tensors or a non-Tensor, if :attr:`func` returns no Tensor output,
            if ``has_aux`` is True but :attr:`func` does not return a 2-tuple,
            or (with ``strict=True``) if an output is independent of the inputs.

    .. warning::
        PyTorch's forward-mode AD coverage on operators is not very good at the
        moment. You may see this API error out with "forward-mode AD not
        implemented for operator X". If so, please file us a bug report and we
        will prioritize it.

    jvp is useful when you wish to compute gradients of a function R^1 -> R^N

    >>> from functorch import jvp
    >>> x = torch.randn([])
    >>> f = lambda x: x * torch.tensor([1., 2., 3])
    >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
    >>> assert torch.allclose(value, f(x))
    >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

    :func:`jvp` can support functions with multiple inputs by passing in the
    tangents for each of the inputs

    >>> from functorch import jvp
    >>> x = torch.randn(5)
    >>> y = torch.randn(5)
    >>> f = lambda x, y: (x * y)
    >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
    >>> assert torch.allclose(output, x + y)
    """
    global JVP_NESTING

    if not isinstance(primals, tuple):
        raise RuntimeError(
            f'{jvp_str}: Expected primals to be a tuple. '
            f'E.g. it should be valid to call f(*primals).')
    flat_primals, primals_spec = tree_flatten(primals)
    flat_tangents, tangents_spec = tree_flatten(tangents)
    if primals_spec != tangents_spec:
        # Pretty-print the specs, matching the equivalent error raised by vjp.
        raise RuntimeError(
            f'{jvp_str}: Expected primals and tangents to have the same python '
            f'structure. For example, if primals is a tuple of 3 tensors, '
            f'tangents also must be. Got primals with structure '
            f'{treespec_pprint(primals_spec)} '
            f'and tangents with structure {treespec_pprint(tangents_spec)}')
    assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals')
    assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents')

    level = _grad_increment_nesting()
    try:
        # Some interesting notes:
        # 1. Can't nested jvp of jvp due to forwardAD restrictions
        # 2. Seems like we can indeed vmap over this, given some more batch rules
        # 3. PyTorch doesn't have a lot of jvp rules implemented right now.
        JVP_NESTING += 1
        # Only the outermost jvp opens a forward-AD dual level; nested calls
        # reuse it via the no-op context manager.
        ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
        with ctx():
            flat_duals = tuple(fwAD.make_dual(p, t)
                               for p, t in zip(flat_primals, flat_tangents))
            duals = tree_unflatten(flat_duals, primals_spec)
            result_duals = func(*duals)
            if has_aux:
                if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
                    raise RuntimeError(
                        f"{jvp_str}: output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                result_duals, aux = result_duals
                aux = _undo_create_differentiable(aux, level)

            result_duals, spec = tree_flatten(result_duals)
            assert_non_empty_tensor_output(result_duals, jvp_str)

            # Unpacking must happen while the dual level is still active.
            primals_out, tangents_out = (
                zip(*[safe_unpack_dual(dual, strict) for dual in result_duals]))
            primals_out = tree_map(
                partial(_undo_create_differentiable, level=level), primals_out)
            tangents_out = tree_map(
                partial(_undo_create_differentiable, level=level), tangents_out)

            primals_out_unflatten = tree_unflatten(primals_out, spec)
            tangents_out_unflatten = tree_unflatten(tangents_out, spec)
            if has_aux:
                return primals_out_unflatten, tangents_out_unflatten, aux
            return primals_out_unflatten, tangents_out_unflatten
    finally:
        _grad_decrement_nesting()
        JVP_NESTING -= 1
def safe_unflatten(tensor, dim, shape):
    """Unflatten dimension ``dim`` of ``tensor`` into ``shape``.

    ``Tensor.unflatten`` is not used for an empty ``shape`` (a 0-dim target);
    in that case ``dim`` must have size 1 and is squeezed away instead.
    """
    if len(shape) != 0:
        return tensor.unflatten(dim, shape)
    assert tensor.shape[dim] == 1
    return tensor.squeeze(dim)
def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False):
    """
    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
    :attr:`argnums` using forward-mode autodiff

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a function that takes in the same inputs as :attr:`func` and
        returns the Jacobian of :attr:`func` with respect to the arg(s) at
        :attr:`argnums`. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by :attr:`func`.

    .. warning::
        PyTorch's forward-mode AD coverage on operators is not very good at the
        moment. You may see this API error out with "forward-mode AD not
        implemented for operator X". If so, please file us a bug report and we
        will prioritize it.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

    >>> from functorch import jacfwd
    >>> x = torch.randn(5)
    >>> jacobian = jacfwd(torch.sin)(x)
    >>> expected = torch.diag(torch.cos(x))
    >>> assert torch.allclose(jacobian, expected)

    :func:`jacfwd` can be composed with vmap to produce batched
    Jacobians:

    >>> from functorch import jacfwd, vmap
    >>> x = torch.randn(64, 5)
    >>> jacobian = vmap(jacfwd(torch.sin))(x)
    >>> assert jacobian.shape == (64, 5, 5)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

    >>> from functorch import jacfwd
    >>> x = torch.randn(5)
    >>>
    >>> def f(x):
    >>>   return x.sin()
    >>>
    >>> def g(x):
    >>>   result = f(x)
    >>>   return result, result
    >>>
    >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
    >>> assert torch.allclose(f_x, f(x))

    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
    to produce Hessians

    >>> from functorch import jacfwd, jacrev
    >>> def f(x):
    >>>   return x.sin().sum()
    >>>
    >>> x = torch.randn(5)
    >>> hessian = jacfwd(jacrev(f))(x)
    >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacfwd` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using :attr:`argnums`:

    >>> from functorch import jacfwd
    >>> def f(x, y):
    >>>   return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacfwd(f, argnums=1)(x, y)
    >>> expected = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
    with respect to multiple arguments

    >>> from functorch import jacfwd
    >>> def f(x, y):
    >>>   return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
    >>> expectedX = torch.diag(torch.ones_like(x))
    >>> expectedY = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian[0], expectedX)
    >>> assert torch.allclose(jacobian[1], expectedY)
    """
    # Preserve func's metadata (name, docstring), as jacrev does.
    @wraps(func)
    def wrapper_fn(*args):
        f_wrapper, primals = _argnums_partial(func, args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        # One basis vector per input element; vmap pushes all of them through jvp.
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(tangents):
            output = jvp(f_wrapper, primals, tangents, has_aux=has_aux)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        # jvp already validated that f returns a non-empty set of Tensors.
        jac_outs, spec = tree_flatten(results)

        # Move the basis (batch) dim last, split it per input, and reshape each
        # piece to the corresponding input's shape.
        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in
                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)

        if isinstance(argnums, int):
            # With a single argnum, primals is a 1-tuple; strip that wrapper.
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)
    return wrapper_fn
def hessian(func, argnums=0):
    """
    Computes the Hessian of :attr:`func` with respect to the arg(s) at index
    :attr:`argnums` via a forward-over-reverse strategy.

    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
    a good default for good performance. It is possible to compute Hessians
    through other compositions of :func:`jacfwd` and :func:`jacrev` like
    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Hessian with respect to.
            Default: 0.

    Returns:
        Returns a function that takes in the same inputs as :attr:`func` and
        returns the Hessian of :attr:`func` with respect to the arg(s) at
        :attr:`argnums`.

    .. warning::
        PyTorch's forward-mode AD coverage on operators is not very good at the
        moment. You may see this API error out with "forward-mode AD not
        implemented for operator X". If so, please file us a bug report and we
        will prioritize it.

    A basic usage with a R^N -> R^1 function gives a N x N Hessian:

        >>> from functorch import hessian
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hess = hessian(f)(x)
        >>> assert torch.allclose(hess, torch.diag(-x.sin()))
    """
    # Inner jacrev: reverse-mode Jacobian of func w.r.t. ``argnums``.
    # Outer jacfwd: forward-mode Jacobian of that result w.r.t. the same
    # ``argnums``, i.e. the matrix of second derivatives.
    return jacfwd(jacrev(func, argnums), argnums)
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """
    Returns a function to compute a tuple of the gradient and primal, or
    forward, computation.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a scalar (0-dimensional) Tensor; a Tensor with any
            other number of dimensions is rejected, even if it has a single
            element. If specified :attr:`has_aux` equals ``True``, function
            must return a 2-tuple ``(output, aux)`` where ``output`` is a
            scalar Tensor and ``aux`` is a Tensor or a pytree (e.g. a tuple)
            whose leaves are Tensors.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients
            with respect to. :attr:`argnums` can be single integer or tuple of
            integers. Default: 0.
        has_aux (bool): Flag indicating that :attr:`func` returns a tensor and
            other auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute a tuple of gradients with respect to its inputs
        and the forward computation. By default, the output of the function is
        a tuple of the gradient tensor(s) with respect to the first argument
        and the primal computation. If specified :attr:`has_aux` equals
        ``True``, tuple of gradients and tuple of the forward computation with
        output auxiliary objects is returned. If :attr:`argnums` is a tuple of
        integers, a tuple of a tuple of the output gradients with respect to
        each :attr:`argnums` value and the forward computation is returned.

    Raises:
        RuntimeError: (when the returned function is called) if
            :attr:`has_aux` is ``True`` and :attr:`func` does not return a
            2-tuple, if the primary output is not a Tensor, if it is not
            0-dimensional, or if ``aux`` contains a non-Tensor leaf.

    See :func:`grad` for examples
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Push a new grad nesting level. Inputs are wrapped at this level and
        # every result is unwrapped from it before being returned.
        level = _grad_increment_nesting()
        try:
            # aux stays None unless func returns one (has_aux=True).
            aux = None
            # See NOTE [grad and vjp interaction with no_grad]
            with torch.enable_grad():
                args = _wrap_all_tensors(args, level)
                kwargs = _wrap_all_tensors(kwargs, level)
                diff_args = _slice_argnums(args, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_args)

                output = func(*args, **kwargs)
                if has_aux:
                    if not (isinstance(output, tuple) and len(output) == 2):
                        raise RuntimeError(
                            "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    output, aux = output

                if not isinstance(output, torch.Tensor):
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       f'to return a Tensor, got {type(output)}')
                # A gradient is only well-defined for a scalar output; anything
                # else needs a vector-Jacobian product (vjp / jacrev).
                if output.dim() != 0:
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       'to return a scalar Tensor, got tensor with '
                                       f'{output.dim()} dims. Maybe you wanted to '
                                       'use the vjp or jacrev APIs instead?')

                flat_diff_args, spec = tree_flatten(diff_args)

                # NB: need create_graph so that backward pass isn't run in no_grad mode
                flat_outputs = _as_tuple(output)
                flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
                grad_input = tree_unflatten(flat_grad_input, spec)

                # Strip this level's wrappers so callers see tensors at the
                # outer level.
                grad_input = _undo_create_differentiable(grad_input, level)
                output = _undo_create_differentiable(output, level)
                if aux is not None:
                    aux = _undo_create_differentiable(aux, level)

            if has_aux:
                return grad_input, (output, aux)
            return grad_input, output
        finally:
            # Always pop the level pushed above, even when func or one of the
            # checks raises.
            _grad_decrement_nesting()
    return wrapper
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """``grad`` operator helps compute gradients of :attr:`func` with respect to the
    input(s) specified by :attr:`argnums`. This operator can be nested to
    compute higher-order gradients.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a scalar (0-dimensional) Tensor. If specified :attr:`has_aux` equals ``True``,
            function can return a tuple of a scalar Tensor and other auxiliary objects:
            ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
            :attr:`argnums` can be single integer or tuple of integers. Default: 0.
        has_aux (bool): Flag indicating that :attr:`func` returns a tensor and other
            auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute gradients with respect to its inputs. By default, the output of
        the function is the gradient tensor(s) with respect to the first argument.
        If specified :attr:`has_aux` equals ``True``, tuple of gradients and output auxiliary objects
        is returned. If :attr:`argnums` is a tuple of integers, a tuple of output gradients with
        respect to each :attr:`argnums` value is returned.

    Example of using ``grad``:

        >>> from functorch import grad
        >>> x = torch.randn([])
        >>> cos_x = grad(lambda x: torch.sin(x))(x)
        >>> assert torch.allclose(cos_x, x.cos())
        >>>
        >>> # Second-order gradients
        >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
        >>> assert torch.allclose(neg_sin_x, -x.sin())

    When composed with ``vmap``, ``grad`` can be used to compute per-sample gradients:

        >>> from functorch import grad
        >>> from functorch import vmap
        >>> batch_size, feature_size = 3, 5
        >>>
        >>> def model(weights, feature_vec):
        >>>     # Very simple linear model with activation
        >>>     assert feature_vec.dim() == 1
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> def compute_loss(weights, example, target):
        >>>     y = model(weights, example)
        >>>     return ((y - target) ** 2).mean()  # MSELoss
        >>>
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>> examples = torch.randn(batch_size, feature_size)
        >>> targets = torch.randn(batch_size)
        >>> inputs = (weights, examples, targets)
        >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

    Example of using ``grad`` with :attr:`has_aux` and :attr:`argnums`:

        >>> from functorch import grad
        >>> def my_loss_func(y, y_pred):
        >>>     loss_per_sample = (0.5 * y_pred - y) ** 2
        >>>     loss = loss_per_sample.mean()
        >>>     return loss, (y_pred, loss_per_sample)
        >>>
        >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
        >>> y_true = torch.rand(4)
        >>> y_preds = torch.rand(4, requires_grad=True)
        >>> out = fn(y_true, y_preds)
        >>> # out is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``grad``.

        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            >>>     grad(f)(x)

        In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``grad`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    # The grad_and_value transform holds no per-call state, so build it once
    # instead of on every invocation of the returned function.
    grad_and_value_fn = grad_and_value(func, argnums, has_aux=has_aux)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # grad is grad_and_value with the primal output dropped.
        results = grad_and_value_fn(*args, **kwargs)
        if has_aux:
            gradients, (_, aux) = results
            return gradients, aux
        gradients, _ = results
        return gradients
    return wrapper