# Source code for torch.autograd.graph
# mypy: allow-untyped-defs
import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
Any,
Callable,
cast,
Deque,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle
log = logging.getLogger(__name__)
# Public names of this module; this list is what ``from torch.autograd.graph import *``
# binds. Names defined below with a leading underscore are intentionally absent.
__all__ = [
"saved_tensors_hooks",
"save_on_cpu",
"disable_saved_tensors_hooks",
"register_multi_grad_hook",
"allow_mutation_on_saved_tensors",
"Node",
"GradientEdge",
"get_gradient_edge",
"increment_version",
]
class Node(abc.ABC):
    """Abstract interface implemented by the nodes of the autograd graph.

    The class is mainly useful for ``isinstance`` / ``issubclass`` checks:
    :meth:`__subclasshook__` reports as subclasses both the classes exposed
    under their own name in ``torch._C._functions`` and any subclass of
    ``torch.autograd.function.BackwardCFunction``.
    """

    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        """Return a tuple of ``(node, index)`` pairs; ``node`` may be ``None``."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        # Internal registration helper; its semantics are defined by the
        # concrete node implementations, not by this interface.
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its arguments, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        # Only customise the check for ``Node`` itself; subclasses of Node
        # fall back to the regular ABC machinery.
        if cls is not Node:
            return NotImplemented
        # True when C is the very class exposed under its own name in
        # torch._C._functions.
        is_builtin_node = C is not None and C is getattr(
            torch._C._functions, C.__name__, None
        )
        if is_builtin_node or issubclass(
            C, torch.autograd.function.BackwardCFunction
        ):
            return True
        return NotImplemented
def _get_grad_fn_or_grad_acc(t):
    """Return ``t.grad_fn``, or a stand-in node when ``t`` requires grad but has none.

    For a tensor that requires grad and has no ``grad_fn``, a view of ``t`` is
    created with grad mode enabled and the first entry of that view's
    ``grad_fn.next_functions`` is returned instead.
    """
    needs_stand_in = t.requires_grad and t.grad_fn is None
    if not needs_stand_in:
        return t.grad_fn
    # Grad mode must be on for the temporary view to get a grad_fn at all.
    with torch.enable_grad():
        view_node = t.view_as(t).grad_fn
        return view_node.next_functions[0][0]
# Field names are given as an explicit tuple. The previous ``("node output_nr")``
# was a parenthesised *string* that only worked because namedtuple splits
# strings on whitespace.
GradientEdge = namedtuple("GradientEdge", ("node", "output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""
def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, it is equivalent to call
    ``g = autograd.grad(loss, input)`` and ``g = autograd.grad(loss, get_gradient_edge(input))``.

    Raises:
        RuntimeError: if ``tensor`` does not require gradients.
    """
    if tensor.requires_grad:
        # output_nr defaults to 0, which is the right value for the
        # AccumulateGrad node.
        return GradientEdge(_get_grad_fn_or_grad_acc(tensor), tensor.output_nr)
    raise RuntimeError(
        "It is not possible to get the gradient edge for a Tensor that does not require gradients"
    )
[docs]def increment_version(tensor):
"""Update autograd metadata tracking whether the given Tensor was modified in place.
This is to enable more accurate error checking within the autograd engine.
It is already done automatically by PyTorch functions and within custom Function
when mark_dirty() is called appropriately so you only need to call this explicitly
if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't
know about. For example a custom kernel that reads the Tensor data_ptr and modifies
the memory inplace based on this pointer.
Note that incrementing the version counter multiple times for a single inplace operation
is not problematic.
"""
torch._C._increment_version(tensor)
class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hooks may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        # The hooks are only stored here; they are installed on __enter__.
        self.pack_hook, self.unpack_hook = pack_hook, unpack_hook

    def __enter__(self):
        # Push this pair onto the stack of default saved-tensor hooks.
        push_hooks = torch._C._autograd._push_saved_tensors_default_hooks
        push_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args: object):
        # Pop the pair pushed by __enter__; exception info is ignored.
        pop_hooks = torch._C._autograd._pop_saved_tensors_default_hooks
        pop_hooks()
class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.
        device_type (str): Name of the ``torch`` attribute (``"cuda"`` by default)
                           whose ``is_available()`` decides whether pinned memory is
                           requested; ``torch.cuda`` is used when ``torch`` has no
                           attribute of that name.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            # The packed value is always ``(original device, CPU copy)``.
            origin = tensor.device
            if pin_memory:
                # Allocate the staging buffer ourselves so it can be pinned;
                # pinning is skipped for sparse tensors and when the device
                # module reports itself unavailable.
                use_pinned = device_module.is_available() and not tensor.is_sparse
                staging = torch.empty(
                    tensor.size(),
                    dtype=tensor.dtype,
                    layout=tensor.layout,
                    pin_memory=use_pinned,
                )
                staging.copy_(tensor)
                return (origin, staging)
            return (origin, tensor.cpu())

        def unpack_from_cpu(packed):
            origin, stored = packed
            # The transfer back is non-blocking only when pinning was requested.
            return stored.to(origin, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)
@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful for if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this
                             error message gets raised.

    On exit the previous state is restored: the hooks are re-enabled if they
    were enabled on entry, otherwise the previously active error message is
    reinstated.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    # Read the previous state *before* entering the try block: if this call
    # raised inside the try, the finally clause would reference an unbound
    # local and mask the real error.
    maybe_prev_message = (
        torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
    )
    try:
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
class _MultiHandle(RemovableHandle):
    """A single removable handle that fans out to a tuple of underlying handles."""

    # The wrapped handles; removed together by :meth:`remove`.
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]):
        # The base-class initialiser is deliberately not invoked: this object
        # only forwards to the handles it wraps.
        self.handles = handles

    def remove(self):
        """Call ``remove()`` on every wrapped handle, in order."""
        for wrapped in self.handles:
            wrapped.remove()

    def __getstate__(self):
        # The pickled state is just the tuple of wrapped handles.
        return self.handles

    def __setstate__(self, state):
        self.handles = state
def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    Raises:
        ValueError: if ``mode`` is neither ``"all"`` nor ``"any"``.

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> handle = torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    supported_modes = ("all", "any")
    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    # Guards the per-graph-task bookkeeping in both modes; the per-tensor
    # hooks all share the dictionaries below.
    lock = threading.Lock()
    handles: Tuple[RemovableHandle, ...]

    if mode == "all":
        fn_all = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
        # All three maps are keyed by graph task id so that interleaved
        # backward passes cannot clobber each other's state.
        count: Dict[int, int] = {}
        nb_calls: Dict[int, int] = {}
        buffer: Dict[int, List[Optional[torch.Tensor]]] = {}
        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx):
            def inner_hook(grad: torch.Tensor):
                task_id = torch._C._current_graph_task_id()
                assert (
                    task_id != -1
                ), "expected this hook to be called inside a backward call"
                with lock:
                    seen_so_far = count.get(task_id, 0)
                    if seen_so_far == 0:
                        # First hook of this graph task: work out how many of
                        # the tracked nodes the engine will actually execute
                        # and allocate the gradient buffer.
                        nb_calls[task_id] = sum(
                            torch._C._will_engine_execute_node(g) for g in grad_fns  # type: ignore[attr-defined]
                        )
                        buffer[task_id] = [None] * len_tensors
                    grads = buffer[task_id]
                    grads[idx] = grad
                    count[task_id] = seen_so_far + 1
                    is_last = count[task_id] == nb_calls[task_id]
                    if is_last:
                        # Drop the bookkeeping before invoking the user hook so
                        # it is released even if the hook raises.
                        del count[task_id]
                        del nb_calls[task_id]
                        del buffer[task_id]
                if is_last:
                    # Called outside the lock so a user hook can never block
                    # hooks that belong to other graph tasks.
                    fn_all(grads)

            return inner_hook

        handles = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn_any = cast(Callable[[torch.Tensor], None], fn)
        # Maps graph task id -> whether the hook already ran for that task.
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn_any)
        def wrapped_fn(grad: torch.Tensor):
            task_id = torch._C._current_graph_task_id()
            assert (
                task_id != -1
            ), "expected this hook to be called inside a backward call"
            with lock:
                already_ran, ran_hook[task_id] = ran_hook[task_id], True
            if already_ran:
                return
            fn_any(grad)

        # Only tensors that require grad get the hook in this mode.
        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]
# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place

# Module-wide flag: set to True while an ``allow_mutation_on_saved_tensors``
# context is active and reset to False when it exits. It is used both to
# reject nested contexts and, in the unpack hook, to detect a backward pass
# running outside the context that recorded the graph.
_allow_mutation_on_saved_tensors_enabled = False
def _get_tid(t) -> Tuple[int, int, int]:
# FIXME: This is almost definitely a bug.
if isinstance(
t,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = t.data_ptr()
return (id(t), data_ptr, t._version)
def _get_sid(t) -> Tuple[int, int]:
# FIXME: This is almost definitely a bug.
if isinstance(
t,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = t.data_ptr()
return (data_ptr, t._version)
class _Handle:
    """Empty marker object; carries no state and is only used for its identity."""
class _swap_with_cloned(saved_tensors_hooks):
    """Saved-tensor hooks that store a handle in the graph instead of the tensor.

    Packing records the tensor in ``ctx.original`` under a ``_Handle``;
    unpacking returns ``ctx.cloned[handle]`` when a clone exists and
    ``ctx.original[handle]`` otherwise.

    Args:
        ctx: the shared bookkeeping object providing ``sid_to_tid``,
            ``tid_to_weakhandle``, ``original`` and ``cloned``.
    """

    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)
            # Tensors saved for backward have an entry in tid_to_weakhandle.
            # NB: The same tensor (of the same version) can be saved multiple
            # times. A single ``get`` is used because the mapping holds its
            # values weakly: a membership test followed by indexing could
            # observe the entry disappearing in between.
            handle: Optional[_Handle] = ctx.tid_to_weakhandle.get(tid)
            if handle is None:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            # Returning the handle stores a strong reference to it in the
            # graph, which keeps the weakly-keyed entries above alive.
            return handle

        def unpack_hook(tup):
            handle = tup
            # Note the space after "context": without it the two literals
            # concatenate into "...contextin which...".
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            # A clone exists only if the tensor was mutated in place after it
            # was saved; in that case the clone holds the saved value.
            if handle in ctx.cloned:
                return ctx.cloned[handle]
            assert handle in ctx.original, error_msg
            return ctx.original[handle]

        super().__init__(pack_hook, unpack_hook)
class _CloneArgBeforeMutateMode(TorchDispatchMode):
    """Dispatch mode that clones saved tensors just before they are written to.

    Args:
        ctx: the shared bookkeeping object providing ``sid_to_tid``,
            ``tid_to_weakhandle``, ``original`` and ``cloned``.
    """

    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        ctx = self.ctx

        for idx, arg in enumerate(func._schema.arguments):
            # Only arguments that the op's schema marks as written to matter.
            if arg.alias_info is None or not arg.alias_info.is_write:
                continue

            t = kwargs["out"] if arg.is_out else args[idx]
            sid = _get_sid(t)
            if sid not in ctx.sid_to_tid:
                continue

            for saved_tid in ctx.sid_to_tid[sid]:
                # We know that if a tid is in sid_to_tid, then it was also in
                # tid_to_weakhandle. However, it is possible for the tensor to be
                # saved at one point, but cleared by backward before it is modified
                # in-place. Consider the following example:
                #
                # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                # >>> out = (a**2).sum()
                # >>> out.backward()
                # >>> a.sin_()
                handle = ctx.tid_to_weakhandle.get(saved_tid)
                if handle is None:
                    continue
                if handle in ctx.cloned:
                    # The same exact tensor has been cloned already
                    continue
                # Preserve the pre-mutation value, then drop the reference to
                # the original that is about to be overwritten.
                ctx.cloned[handle] = ctx.original[handle].clone()
                del ctx.original[handle]

        return func(*args, **kwargs)
class _AllowMutationOnSavedContext:
    """Bookkeeping shared by the hooks behind ``allow_mutation_on_saved_tensors``.

    Attributes are documented inline in ``__init__``; :meth:`clear` empties
    all of them.
    """

    def __init__(self):
        # handle -> clone taken before an in-place write (weak keys).
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        # handle -> tensor as originally saved (weak keys).
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        # tid -> handle, holding the handle weakly.
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        # sid -> set of tids recorded under that sid.
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        """Empty every mapping held by this context."""
        for mapping in (
            self.cloned,
            self.original,
            self.tid_to_weakhandle,
            self.sid_to_tid,
        ):
            mapping.clear()
@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Raises:
        RuntimeError: if another ``allow_mutation_on_saved_tensors`` context is
            already active. The active context is left untouched.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    # Reject nesting before touching any state. Raising from inside the
    # try/finally below would reset the flag to False while the outer context
    # is still active, breaking the outer context's backward pass.
    if _allow_mutation_on_saved_tensors_enabled:
        raise RuntimeError(
            "allow_mutation_on_saved_tensors contexts cannot be nested"
        )

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        # Only flip the flag once both inner context managers were entered.
        _allow_mutation_on_saved_tensors_enabled = True
        try:
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False
def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    """Attach a logging pre-hook to every node reachable from ``t_outputs``.

    Each pre-hook emits one ``log.debug`` line naming the node being executed
    and the dtype/shape of each of its ``grad_outputs``.

    Returns:
        A zero-argument callable that removes every hook registered here.
    """
    root_nodes = [_get_grad_fn_or_grad_acc(t) for t in t_outputs]

    def iter_graph(roots):
        # Breadth-first walk over ``next_functions``; every node is yielded
        # once, after its successors have been queued.
        if not roots:
            return
        live_roots = [root for root in roots if root is not None]
        seen = set(live_roots)
        pending: Deque = collections.deque(live_roots)
        while pending:
            current = pending.popleft()
            for successor, _idx in current.next_functions:
                if successor is None or successor in seen:
                    continue
                seen.add(successor)
                pending.append(successor)
            yield current

    def fmt(t):
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        dims = ", ".join(map(str, t.shape))
        return f"{dtype_abbrs[t.dtype]}[{dims}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        rendered = ",".join(fmt(t) for t in grad_outputs)
        log.debug(f"Executing: {node} with grad_outputs: [{rendered}]")

    handles = [node.register_prehook(prehook) for node in iter_graph(root_nodes)]

    def unregister_hooks():
        for registered in handles:
            registered.remove()

    return unregister_hooks
def _engine_run_backward(t_outputs, *args, **kwargs):
    """Run the engine's backward pass, logging each executed node when DEBUG is on.

    When the module logger's effective level is DEBUG or lower, logging
    pre-hooks are attached to the whole graph of ``t_outputs`` before the run
    and removed afterwards, whether or not the run raised.

    Returns:
        Whatever ``Variable._execution_engine.run_backward`` returns.
    """
    debug_logging = log.getEffectiveLevel() <= logging.DEBUG
    # ``remove_hooks`` stays None when no hooks were attached.
    remove_hooks = (
        _register_logging_hooks_on_whole_graph(t_outputs) if debug_logging else None
    )
    try:
        # Calls into the C++ engine to run the backward pass
        return Variable._execution_engine.run_backward(t_outputs, *args, **kwargs)
    finally:
        if remove_hooks is not None:
            remove_hooks()