OptimizerInBackwardWrapper
- class torchtune.utils.OptimizerInBackwardWrapper(optim_map: Dict[str, Optimizer])
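A bare-bones class for saving and loading checkpoints of optimizers that run in backward, i.e., one optimizer per parameter, stepped during the backward pass rather than in a separate optimizer step. Its usage is limited to checkpoint save and load, as shown in the example below.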
Note
This wrapper is only meant for single-device training. Distributed use cases such as FSDP, which require specialized optimizer state checkpointing, are not supported.
- Parameters:
optim_map (Dict[str, torch.optim.Optimizer]) – Mapping from parameter names to optimizers.
Example
>>> import torch
>>> from torchtune import config
>>> from torchtune.utils import OptimizerInBackwardWrapper
>>>
>>> # optim_map is keyed by parameter *names*, so use named_parameters().
>>> optim_dict = {
>>>     n: config.instantiate(cfg_optimizer, [p])
>>>     for n, p in self._model.named_parameters()
>>> }
>>>
>>> # Save checkpoint (torch.save takes the object first, then the path)
>>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict()
>>> torch.save(ckpt, "/tmp/optim_ckpt")
>>>
>>> # Load checkpoint
>>> placeholder_optim_dict = {
>>>     n: config.instantiate(cfg_optimizer, [p])
>>>     for n, p in self._model.named_parameters()
>>> }
>>>
>>> wrapper = OptimizerInBackwardWrapper(placeholder_optim_dict)
>>>
>>> # load_state_dict expects a dict produced by this class's
>>> # state_dict method.
>>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt"))
>>> # placeholder_optim_dict now has updated optimizer states.
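To see what the save/load round trip in the example amounts to, here is a self-contained sketch of a wrapper with the same two methods. SketchOptimInBwdWrapper and build_optim_map are hypothetical names, and the per-name delegation to each optimizer's own state_dict / load_state_dict is an assumption about the design, not torchtune's source. SGD with momentum is used only so that each optimizer carries state worth checkpointing, and an in-memory buffer stands in for "/tmp/optim_ckpt".

```python
import io
from typing import Any, Dict

import torch
import torch.nn as nn


class SketchOptimInBwdWrapper:
    """Hypothetical stand-in for OptimizerInBackwardWrapper; not torchtune's code."""

    def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]) -> None:
        self.optim_map = optim_map

    def state_dict(self) -> Dict[str, Any]:
        # One entry per parameter name, holding that optimizer's own state dict.
        return {name: optim.state_dict() for name, optim in self.optim_map.items()}

    def load_state_dict(self, optim_ckpt_map: Dict[str, Any]) -> None:
        # Delegate to each per-parameter optimizer (no key validation here).
        for name, optim_state in optim_ckpt_map.items():
            self.optim_map[name].load_state_dict(optim_state)


def build_optim_map(model: nn.Module) -> Dict[str, torch.optim.Optimizer]:
    # Momentum gives each optimizer non-trivial state worth checkpointing.
    return {
        name: torch.optim.SGD([p], lr=0.1, momentum=0.9)
        for name, p in model.named_parameters()
    }


torch.manual_seed(0)
model = nn.Linear(4, 2)
optim_map = build_optim_map(model)

# One training step so that every optimizer has a momentum buffer.
model(torch.randn(3, 4)).sum().backward()
for optim in optim_map.values():
    optim.step()

# Save: torch.save takes the object first and the destination second.
buffer = io.BytesIO()
torch.save(SketchOptimInBwdWrapper(optim_map).state_dict(), buffer)
buffer.seek(0)

# Load into freshly built placeholder optimizers.
placeholder_map = build_optim_map(model)
SketchOptimInBwdWrapper(placeholder_map).load_state_dict(torch.load(buffer))
```

After loading, each placeholder optimizer holds the same momentum buffer as its original counterpart, mirroring the final comment in the example.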
- get_optim_key(key: str) → Any
Returns the value of key (for example, "lr") from an arbitrary optimizer running in backward. This assumes that all optimizers running in backward have the same value for key, i.e., that they were initialized with the same hyperparameters.
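A minimal sketch of this lookup follows, assuming the value is read from the first parameter group of the first optimizer in the map. The free function get_optim_key shown here is a hypothetical illustration, not the method's actual source, and AdamW with these hyperparameters is an arbitrary choice.

```python
from typing import Any, Dict

import torch
import torch.nn as nn


def get_optim_key(optim_map: Dict[str, torch.optim.Optimizer], key: str) -> Any:
    # Hypothetical helper: read `key` from the first param group of the first
    # optimizer in the map. This is only meaningful when every optimizer was
    # built with the same hyperparameters, since the others are never checked.
    first_optim = next(iter(optim_map.values()))
    return first_optim.param_groups[0][key]


model = nn.Linear(4, 2)
optim_map = {
    name: torch.optim.AdamW([p], lr=3e-4, weight_decay=0.01)
    for name, p in model.named_parameters()
}

current_lr = get_optim_key(optim_map, "lr")  # e.g. for logging
```

Reading a single optimizer keeps the lookup cheap; a stricter variant could compare the value across all optimizers in the map before returning it.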
- load_state_dict(optim_ckpt_map: Dict[str, Any])
Load optimizer states from a state dict produced by this class’s state_dict method.
- Parameters:
optim_ckpt_map (Dict[str, Any]) – state dict mapping parameter names to optimizer states.
- Raises:
RuntimeError – If the optimizer state dict does not contain all the expected parameters.
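The documented RuntimeError implies a coverage check before the per-parameter states are loaded. The sketch below shows one way to express it. validate_optim_ckpt_keys is a hypothetical name, and also rejecting parameter names the wrapper does not manage is an extra assumption of this sketch rather than documented behavior.

```python
from typing import Any, Dict, Iterable


def validate_optim_ckpt_keys(
    expected_names: Iterable[str], optim_ckpt_map: Dict[str, Any]
) -> None:
    """Hypothetical check mirroring the documented RuntimeError."""
    expected = set(expected_names)
    provided = set(optim_ckpt_map)

    # Documented case: every managed parameter must have saved state.
    missing = sorted(expected - provided)
    if missing:
        raise RuntimeError(
            f"Optimizer checkpoint is missing state for parameters: {missing}"
        )

    # Extra assumption of this sketch: also reject names the wrapper does not know.
    unexpected = sorted(provided - expected)
    if unexpected:
        raise RuntimeError(
            f"Optimizer checkpoint has unexpected parameters: {unexpected}"
        )
```

For example, validate_optim_ckpt_keys(["weight", "bias"], {"weight": {}}) raises a RuntimeError that names "bias", whereas a checkpoint containing both names passes silently.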