register_optim_in_bwd_hooks
torchtune.utils.register_optim_in_bwd_hooks(model: Module, optim_dict: Dict[Parameter, Optimizer]) → None
Register hooks that run the optimizer step during the backward pass.
When fusing the optimizer step into backward, we need to call .step() on the optimizer for a given parameter as soon as its gradient is ready. This utility registers post-accumulate-grad hooks on all parameters in the model to achieve this.
Parameters:
model (torch.nn.Module) – Model whose parameters will be optimized. Note that currently hooks for ALL parameters in the model will be registered.
optim_dict (Dict[torch.nn.Parameter, torch.optim.Optimizer]) – Mapping from each parameter to the optimizer that steps it.