torch.Tensor.register_post_accumulate_grad_hook

Tensor.register_post_accumulate_grad_hook(hook)

Registers a backward hook that runs after grad accumulation.

The hook is called after all gradients for a tensor have been accumulated, meaning that the .grad field has already been updated on that tensor. The post accumulate grad hook is ONLY applicable to leaf tensors (tensors without a .grad_fn field). Registering this hook on a non-leaf tensor raises an error.

The hook should have the following signature:

hook(param: Tensor) -> None

Note that, unlike other autograd hooks, this hook receives the tensor that requires grad, not the gradient itself. The hook can access its Tensor argument and modify it in place, including its .grad field.

This function returns a handle with a method handle.remove() that removes the hook from the tensor.

Note

See Backward Hooks execution for more information on when this hook is executed and how its execution is ordered relative to other hooks. Since this hook runs during the backward pass, it runs in no_grad mode (unless create_graph is True). You can use torch.enable_grad() to re-enable autograd within the hook if you need it.

Example:

>>> import torch
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook
