torchtnt.utils.optimizer.init_optim_state
torchtnt.utils.optimizer.init_optim_state(optimizer: Optimizer) -> None
Initialize optimizer states by calling step() with zero grads. This is necessary because some optimizers, such as AdamW, initialize parts of their state lazily: the entries appear in the state dict only after step() has been called for the first time.
Certain checkpointing solutions rely on in-place loading, which reuses the tensor memory already allocated in the optimizer state dict. This optimization does not work with optimizers that initialize their states lazily, because states that do not yet exist cannot be restored. Calling this function ensures that these states are present in the state dict for in-place loading.
Parameters: optimizer – A PyTorch optimizer.
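
The lazy initialization described above can be observed directly. The following is a minimal sketch using only PyTorch; `init_state_sketch` is a hypothetical stand-in written for illustration, not torchtnt's implementation, and it assumes the parameters have no gradients yet.

```python
import torch


def init_state_sketch(optimizer: torch.optim.Optimizer) -> None:
    """Hypothetical stand-in for init_optim_state (not torchtnt's code)."""
    # Give every trainable parameter a zero gradient so step() visits it.
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.requires_grad and param.grad is None:
                param.grad = torch.zeros_like(param)
    # The first step() call is what allocates the lazily created state.
    optimizer.step()
    # Drop the placeholder gradients again.
    optimizer.zero_grad(set_to_none=True)


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Before any step(), the per-parameter state is empty.
state_before = dict(optimizer.state_dict()["state"])

init_state_sketch(optimizer)

# Afterwards, each parameter has its state tensors allocated.
state_after = optimizer.state_dict()["state"]
print(len(state_before), len(state_after))
print(sorted(state_after[0].keys()))
```

With torchtnt installed, the equivalent call per the signature above would be `init_optim_state(optimizer)`, made before handing the optimizer to a checkpoint loader that restores in place.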