# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Union, Tuple
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from .pytree_hacks import tree_map_, treespec_pprint
import torch.autograd.forward_ad as fwAD
from .vmap import vmap
from functorch._C import (
_wrap_for_grad,
_unwrap_for_grad,
_grad_increment_nesting,
_grad_decrement_nesting,
)
argnums_t = Union[int, Tuple[int, ...]]
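# Internal helpers below: mark every Tensor leaf of a pytree as requiring grad so
# it participates in autograd for the current transform, and undo that afterwards
# by unwrapping the per-level grad wrappers. Non-Tensor leaves raise an error.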
def _create_differentiable(inps, level=None):
def create_differentiable(x):
if isinstance(x, torch.Tensor):
return x.requires_grad_()
raise ValueError(f'Thing passed to transform API must be Tensor, '
f'got {type(x)}')
return tree_map(create_differentiable, inps)
def _undo_create_differentiable(inps, level=None):
def unwrap_tensors(x):
if isinstance(x, torch.Tensor):
return _unwrap_for_grad(x, level)
# TODO: Remove the following hack for namedtuples
if isinstance(x, tuple):
return tree_map(unwrap_tensors, tuple(x))
raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
return tree_map(unwrap_tensors, inps)
def _is_differentiable(maybe_tensor):
if not isinstance(maybe_tensor, torch.Tensor):
return False
return maybe_tensor.requires_grad
def _any_differentiable(tensor_or_tuple_of_tensors):
flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
return any(tuple(map(_is_differentiable, flat_args)))
def _wrap_tensor_for_grad(maybe_tensor, level):
if not isinstance(maybe_tensor, torch.Tensor):
return maybe_tensor
return _wrap_for_grad(maybe_tensor, level)
def _wrap_all_tensors(tensor_pytree, level):
return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
def _as_tuple(val):
if isinstance(val, tuple):
return val
return (val,)
# Version of autograd.grad that handles outputs that don't depend on inputs
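# A minimal usage sketch (illustrative; `w` and `unused` are hypothetical tensors
# with requires_grad=True): if `out = w * 2` and `unused` does not contribute to
# `out`, then
#   _autograd_grad((out,), (w, unused), (torch.ones_like(out),))
# is expected to return (2 * torch.ones_like(w), torch.zeros_like(unused)),
# i.e. inputs that do not affect any output get zero gradients instead of None.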
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
if len(result) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*result)
if len(diff_outputs) == 0:
return tuple(torch.zeros_like(inp) for inp in inputs)
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True)
grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
for gi, inp in zip(grad_inputs, inputs))
return grad_inputs
# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
# with torch.no_grad():
# c = x ** 2
# return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
# with torch.no_grad():
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad, but not the
# outer one. This is because `grad` is a "function transform": its result
# should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
# active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
# then we temporarily restore the previous `is_grad_enabled`. This is
# because we're crossing the boundary from a `grad` outside the
# no_grad to a `grad` inside the no_grad.
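#
# A concrete (illustrative) sketch of the desired behavior, using f from above:
#
#   x = torch.tensor(2.)
#   grad(f)(x)                 # respects the inner no_grad: d/dx (x - c) = 1
#   with torch.no_grad():
#       grad(f)(x)             # same result; the outer no_grad is ignored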
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# TODO: forward-mode AD: does it also respect no_grad? What does that mean
# for our jvp transform?
# How do we increment and decrement the nesting? I don't think we can.
def vjp(f: Callable, *primals):
"""
Standing for the vector-Jacobian product, returns a tuple containing the
results of :attr:`f` applied to :attr:`primals` and a function that, when
given ``cotangents``, computes the reverse-mode Jacobian of :attr:`f` with
respect to :attr:`primals` times ``cotangents``.
Args:
f (Callable): A Python function that takes one or more arguments. Must
return one or more Tensors.
primals (Tensors): Positional arguments to :attr:`f` that must all be
Tensors. The returned function will also compute the derivative with
respect to these arguments.
Returns:
Returns a tuple containing the output of :attr:`f` applied to
:attr:`primals` and a function that computes the vjp of :attr:`f` with
respect to all :attr:`primals` using the cotangents passed to the
returned function. The returned function will return a tuple containing
the VJP with respect to each primal.
When used in simple cases, :func:`vjp` behaves the same as :func:`grad`
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, functorch.grad(f)(x))
However, :func:`vjp` can support functions with multiple outputs by
passing in the cotangents for each of the outputs
>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
:func:`vjp` can even support outputs being Python structs
>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by :func:`vjp` will compute the partials with
respect to each of the :attr:`primals`
>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
:attr:`primals` are the positional arguments for :attr:`f`. All kwargs use their
default values
>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>> return x * scale
>>>
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
.. note::
Using PyTorch ``torch.no_grad`` together with ``vjp``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> def f(x):
>>> with torch.no_grad():
>>> c = x ** 2
>>> return x - c
In this case, ``vjp(f, x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:
>>> with torch.no_grad():
>>> vjp(f, x)
In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``vjp`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
"""
level = _grad_increment_nesting()
try:
# See NOTE [grad and vjp interaction with no_grad]
with torch.enable_grad():
primals = _wrap_all_tensors(primals, level)
diff_primals = _create_differentiable(primals, level)
primals_out = f(*diff_primals)
results = _undo_create_differentiable(primals_out, level)
flat_diff_primals, primals_spec = tree_flatten(diff_primals)
flat_primals_out, primals_out_spec = tree_flatten(primals_out)
for primal_out in flat_primals_out:
assert isinstance(primal_out, torch.Tensor)
if primal_out.is_floating_point() or primal_out.is_complex():
continue
raise RuntimeError("vjp(f, ...): All outputs of f must be "
"floating-point or complex Tensors, got Tensor "
f"with dtype {primal_out.dtype}")
def wrapper(cotangents, retain_graph=True, create_graph=None):
if create_graph is None:
create_graph = torch.is_grad_enabled()
flat_cotangents, cotangents_spec = tree_flatten(cotangents)
if primals_out_spec != cotangents_spec:
raise RuntimeError(
f'Expected pytree structure of cotangents to be the same '
f'as pytree structure of outputs to the function. '
f'cotangents: {treespec_pprint(cotangents_spec)}, '
f'primal output: {treespec_pprint(primals_out_spec)}')
result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
retain_graph=retain_graph, create_graph=create_graph)
return tree_unflatten(result, primals_spec)
finally:
_grad_decrement_nesting()
return results, wrapper
def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0):
"""
Computes the Jacobian of :attr:`f` with respect to the arg(s) at index
:attr:`argnums` using reverse-mode autodiff.
Args:
f (Callable): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Jacobian with respect to.
Default: 0.
Returns:
Returns a function that takes in the same inputs as :attr:`f` and
returns the Jacobian of :attr:`f` with respect to the arg(s) at
:attr:`argnums`
A basic usage with a pointwise, unary operation will give a diagonal array
as the Jacobian
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
:func:`jacrev` can be composed with vmap to produce batched
Jacobians:
>>> from functorch import jacrev
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
Additionally, :func:`jacrev` can be composed with itself to produce
Hessians
>>> from functorch import jacrev
>>> def f(x):
>>> return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, :func:`jacrev` computes the Jacobian with respect to the first
input. However, it can compute the Jacobian with respect to a different
argument by using :attr:`argnums`:
>>> from functorch import jacrev
>>> def f(x, y):
>>> return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
with respect to multiple arguments
>>> from functorch import jacrev
>>> def f(x, y):
>>> return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0,1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
.. note::
Using PyTorch ``torch.no_grad`` together with ``jacrev``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> def f(x):
>>> with torch.no_grad():
>>> c = x ** 2
>>> return x - c
In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:
>>> with torch.no_grad():
>>> jacrev(f)(x)
In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``jacrev`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
"""
@wraps(f)
def wrapper_fn(*args):
f_wrapper, primals = _argnums_partial(f, args, argnums)
output, vjp_fn = vjp(f_wrapper, *primals)
assert isinstance(output, torch.Tensor)
# TODO: does jacrev compose with vmap...? the eye call should make it so that it doesn't
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device) \
.view(output.numel(), *output.shape)
results = vmap(vjp_fn)(basis)
results = tuple(r.view(*output.shape, *p.shape) for (r, p) in zip(results, primals))
return results if len(results) > 1 else results[0]
return wrapper_fn
def _check_unique_non_empty(argnums):
if isinstance(argnums, tuple):
if len(argnums) == 0:
raise RuntimeError("argnums must be non-empty")
if len(set(argnums)) != len(argnums):
raise RuntimeError(f"argnums elements must be unique, got {argnums}")
def _replace_args(old_args, new_args, argnums):
if isinstance(argnums, int):
if len(new_args) == 1:
return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
else:
raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
if isinstance(argnums, tuple):
if len(new_args) == len(argnums):
get_right_elem = lambda i : new_args[argnums.index(i)] if i in argnums else old_args[i]
return tuple(get_right_elem(i) for i in range(len(old_args)))
else:
raise RuntimeError("new_args should have the same size as argnums. "
f"Argnums size {len(argnums)}, new_args size {len(new_args)}")
raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
def _safe_index(args, argnum):
if not isinstance(argnum, int):
raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
if 0 <= argnum < len(args):
return args[argnum]
raise RuntimeError(f'Got argnum={argnum}, but only {len(args)} positional inputs')
def _slice_argnums(args, argnums):
_check_unique_non_empty(argnums)
if isinstance(argnums, int):
return _safe_index(args, argnums)
if isinstance(argnums, tuple):
return tuple(_safe_index(args, i) for i in argnums)
raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
def _argnums_partial(f, args, argnums):
def f_wrapper(*wrapper_args):
replaced_args = _replace_args(args, wrapper_args, argnums)
return f(*replaced_args)
wrapper_args = _slice_argnums(args, argnums)
wrapper_args = wrapper_args if isinstance(wrapper_args, tuple) else (wrapper_args, )
return (f_wrapper, wrapper_args)
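# A minimal sketch of how the argnums helpers compose (illustrative only;
# x, y, z are hypothetical tensors):
#
#   def f(x, y, z): return x * y + z
#   f_wrapper, wrapper_args = _argnums_partial(f, (x, y, z), argnums=(0, 2))
#   # wrapper_args == (x, z); calling f_wrapper(new_x, new_z) re-inserts the new
#   # values at positions 0 and 2 and reuses the original y.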
JVP_NESTING = 0
@contextlib.contextmanager
def noop():
yield
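# jvp(f, primals, tangents) computes a forward-mode Jacobian-vector product:
# it evaluates f at `primals` and pushes `tangents` through. A minimal sketch
# of the expected behavior (illustrative only):
#
#   x = torch.randn(3)
#   v = torch.ones(3)
#   out, jvp_out = jvp(torch.sin, (x,), (v,))
#   # out == torch.sin(x), jvp_out == torch.cos(x) * v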
def jvp(f, primals, tangents):
level = _grad_increment_nesting()
try:
# Some interesting notes:
# 1. Can't nest jvp inside jvp due to forward-mode AD restrictions.
# 2. Seems like we can indeed vmap over this, given some more batch rules.
# 3. PyTorch doesn't have a lot of jvp rules implemented right now.
global JVP_NESTING
JVP_NESTING += 1
ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
with ctx():
# TODO: extend this to any number of primals
assert len(primals) == 1 and len(tangents) == 1
duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents))
result_duals = f(*duals)
result_duals, _ = tree_flatten(result_duals)
assert len(result_duals) == 1
primals_out, tangents_out = fwAD.unpack_dual(result_duals[0])
primals_out = _undo_create_differentiable(primals_out, level)
tangents_out = _undo_create_differentiable(tangents_out, level)
return primals_out, tangents_out
finally:
_grad_decrement_nesting()
JVP_NESTING -= 1
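# jacfwd(f) builds the Jacobian via forward mode by pushing one-hot basis
# vectors through jvp under vmap. A minimal sketch (illustrative only):
#
#   x = torch.randn(4)
#   J = jacfwd(torch.sin)(x)
#   # J is expected to equal torch.diag(torch.cos(x))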
def jacfwd(f):
# TODO: This should take more than just a single primal...
def wrapper_fn(primal):
basis = torch.eye(primal.numel(), dtype=primal.dtype, device=primal.device) \
.view(primal.numel(), *primal.shape)
def push_jvp(basis):
_, jvp_out = jvp(f, (primal,), (basis,))
return jvp_out
result = vmap(push_jvp)(basis)
result = result.view(*primal.shape, *primal.shape)
return result
return wrapper_fn
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""
Returns a function to compute a tuple of the gradient and the primal
(i.e. forward) computation.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If :attr:`has_aux` is ``True``,
the function can instead return a tuple of a single-element Tensor
and other auxiliary objects: ``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients
with respect to. :attr:`argnums` can be single integer or tuple of
integers. Default: 0.
has_aux (bool): Flag indicating that :attr:`func` returns a tensor and
other auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute a tuple of gradients with respect to its inputs
and the forward computation. By default, the output of the function is
a tuple of the gradient tensor(s) with respect to the first argument
and the primal computation. If :attr:`has_aux` is ``True``, a tuple of
gradients and a tuple of the forward computation with the auxiliary
objects is returned. If :attr:`argnums` is a tuple of integers, a tuple
of a tuple of the output gradients with respect to each :attr:`argnums`
value and the forward computation is returned.
See :func:`grad` for examples
"""
@wraps(func)
def wrapper(*args, **kwargs):
level = _grad_increment_nesting()
output, aux, grad_input = None, None, None
try:
# See NOTE [grad and vjp interaction with no_grad]
with torch.enable_grad():
args = _wrap_all_tensors(args, level)
kwargs = _wrap_all_tensors(kwargs, level)
diff_args = _slice_argnums(args, argnums)
tree_map_(partial(_create_differentiable, level=level), diff_args)
output = func(*args, **kwargs)
if has_aux:
output, aux = output
if not isinstance(output, torch.Tensor):
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
f'to return a Tensor, got {type(output)}')
if output.dim() != 0:
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
'to return a scalar Tensor, got tensor with '
f'{output.dim()} dims. Maybe you wanted to '
'use the vjp or jacrev APIs instead?')
flat_diff_args, spec = tree_flatten(diff_args)
# NB: need create_graph so that backward pass isn't run in no_grad mode
flat_outputs = _as_tuple(output)
flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
grad_input = tree_unflatten(flat_grad_input, spec)
finally:
if grad_input is not None:
grad_input = _undo_create_differentiable(grad_input, level)
if output is not None:
output = _undo_create_differentiable(output, level)
if aux is not None:
aux = _undo_create_differentiable(aux, level)
_grad_decrement_nesting()
if has_aux:
return grad_input, (output, aux)
return grad_input, output
return wrapper
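# A minimal usage sketch for grad_and_value (illustrative only):
#
#   x = torch.randn([])
#   g, value = grad_and_value(torch.sin)(x)
#   # g is expected to equal x.cos(), and value to equal x.sin()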
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""``grad`` operator helps computing gradients of :attr:`func` with respect to the
input(s) specified by :attr:`argnums`. This operator can be nested to
compute higher-order gradients.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If :attr:`has_aux` is ``True``, the
function can instead return a tuple of a single-element Tensor and other
auxiliary objects: ``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
:attr:`argnums` can be single integer or tuple of integers. Default: 0.
has_aux (bool): Flag indicating that :attr:`func` returns a tensor and other
auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute gradients with respect to its inputs. By default, the output of
the function is the gradient tensor(s) with respect to the first argument.
If :attr:`has_aux` is ``True``, a tuple of gradients and the output auxiliary objects
is returned. If :attr:`argnums` is a tuple of integers, a tuple of output gradients with
respect to each :attr:`argnums` value is returned.
Example of using ``grad``:
>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
>>> from functorch import grad
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>> # Very simple linear model with activation
>>> assert feature_vec.dim() == 1
>>> return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>> y = model(weights, example)
>>> return ((y - target) ** 2).mean() # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
Example of using ``grad`` with :attr:`has_aux` and :attr:`argnums`:
>>> from functorch import grad
>>> def my_loss_func(y, y_pred):
>>> loss_per_sample = (0.5 * y_pred - y) ** 2
>>> loss = loss_per_sample.mean()
>>> return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # output is ((grads w.r.t. y_true, grads w.r.t. y_preds), (y_pred, loss_per_sample))
.. note::
Using PyTorch ``torch.no_grad`` together with ``grad``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> def f(x):
>>> with torch.no_grad():
>>> c = x ** 2
>>> return x - c
In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
>>> with torch.no_grad():
>>> grad(f)(x)
In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``grad`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
"""
@wraps(func)
def wrapper(*args, **kwargs):
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
if has_aux:
grad, (_, aux) = results
return grad, aux
grad, _ = results
return grad
return wrapper