torch.Tensor.register_post_accumulate_grad_hook
Tensor.register_post_accumulate_grad_hook(hook)
Registers a backward hook that runs after grad accumulation.
The hook is called after all gradients for a tensor have been accumulated, meaning that the .grad field has been updated on that tensor. The post accumulate grad hook is ONLY applicable to leaf tensors (tensors whose .grad_fn is None). Registering this hook on a non-leaf tensor raises an error.
The hook should have the following signature:
hook(param: Tensor) -> None
Note that, unlike other autograd hooks, this hook operates on the tensor that requires grad and not the grad itself. The hook can in-place modify and access its Tensor argument, including its .grad field.
This function returns a handle with a method
handle.remove()
that removes the hook from the tensor.
Note
See Backward Hooks execution for more information on when this hook is executed and how its execution is ordered relative to other hooks. Since this hook runs during the backward pass, it runs in no_grad mode (unless create_graph is True). You can use torch.enable_grad() to re-enable autograd within the hook if you need it.
Example:
>>> import torch
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
>>> h.remove()  # removes the hook