import torch
import warnings
def detach_variable(inputs):
    """Return a copy of ``inputs`` with every tensor detached from its graph.

    Each detached tensor copies the ``requires_grad`` flag of the tensor it
    came from, so it is a graph-free leaf that can accumulate ``.grad`` when a
    new graph is built on top of it. Non-tensor entries are passed through
    unchanged.

    Args:
        inputs: a tuple whose tensor entries should be detached.

    Returns:
        tuple: the processed entries, in the original order.

    Raises:
        RuntimeError: if ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build a single message string; passing two positional args to the
        # exception would render the message as a tuple repr.
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: "
            + type(inputs).__name__)
    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue
        x = inp.detach()
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)
def check_backward_validity(inputs):
    """Warn when no tensor in ``inputs`` has ``requires_grad=True``.

    Args:
        inputs: iterable of checkpoint inputs. Entries that are not tensors
            are ignored, since they have no ``requires_grad`` attribute.
    """
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that drops intermediate activations in forward.

    ``forward`` runs ``run_function`` under ``torch.no_grad()`` and keeps only
    the inputs; ``backward`` re-runs ``run_function`` with grad enabled on
    detached copies of those inputs and back-propagates through the rebuilt
    graph.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        """Run ``run_function(*args)`` without recording an autograd graph.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            run_function: callable to checkpoint; it is called again in
                ``backward``.
            *args: inputs to ``run_function``; they are stored with
                ``ctx.save_for_backward``.

        Returns:
            Whatever ``run_function`` returns.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute ``run_function`` with grad enabled and back-propagate.

        Args:
            ctx: context populated by ``forward``.
            *args: incoming gradients, one per output of ``forward``.

        Returns:
            tuple: ``None`` for ``run_function`` followed by ``.grad`` of each
            detached saved input (``None`` where the input did not require
            grad).

        Raises:
            RuntimeError: when checkpointing is not valid in the current
                backward call (reached via ``.grad()`` rather than
                ``.backward()``).
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # torch.autograd.backward raises if any tensor it is given does not
        # require grad (e.g. a recomputed output that does not depend on the
        # inputs), so only pair up outputs that are part of the new graph.
        outputs_with_grad = []
        grads_for_outputs = []
        for out, grad in zip(outputs, args):
            if isinstance(out, torch.Tensor) and out.requires_grad:
                outputs_with_grad.append(out)
                grads_for_outputs.append(grad)
        if outputs_with_grad:
            torch.autograd.backward(outputs_with_grad, grads_for_outputs)
        return (None,) + tuple(inp.grad for inp in detached_inputs)
def checkpoint(function, *args):
    r"""Checkpoint a model or part of the model.

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied to any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    return CheckpointFunction.apply(function, *args)
def checkpoint_sequential(functions, segments, *inputs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model. Must be at least 1;
            values larger than the number of functions are reduced so that
            every chunk holds at least one function.
        inputs: tuple of Tensors that are inputs to :attr:`functions`. Only
            the first one is passed to the first function of the chain.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Raises:
        ValueError: if ``segments`` is less than 1.

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    def run_function(start, end, functions):
        # Build a closure that applies functions[start..end] (inclusive) in order.
        def forward(*inputs):
            # Only the first positional input is threaded through the chain.
            x = inputs[0]
            for j in range(start, end + 1):
                x = functions[j](x)
            return x
        return forward

    if segments < 1:
        raise ValueError(
            "checkpoint_sequential expects segments >= 1, got {}".format(segments))

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    num_functions = len(functions)
    # More segments than functions would give a segment size of 0 and a
    # zero-step range below; cap it so each segment holds >= 1 function.
    segments = min(segments, num_functions)

    # the last chunk has to be non-volatile
    end = -1
    if segments > 1:
        segment_size = num_functions // segments
        for start in range(0, segment_size * (segments - 1), segment_size):
            end = start + segment_size - 1
            inputs = checkpoint(run_function(start, end, functions), *inputs)
            if not isinstance(inputs, tuple):
                inputs = (inputs,)
    return run_function(end + 1, num_functions - 1, functions)(*inputs)