OptimizerInBackwardWrapper

class torchtune.training.OptimizerInBackwardWrapper(optim_map: Dict[str, Optimizer])[source]

A bare-bones class for saving and loading checkpoints of optimizers that run in the backward pass. Usage is limited to the pattern shown in the example below.

Note

This wrapper is only meant to be used for single-device use cases. Distributed use cases such as FSDP, which require specialized optimizer state checkpointing, are not supported.

Parameters:

optim_map (Dict[str, torch.optim.Optimizer]) – Mapping from parameter names to optimizers.

Example

>>> # Key each optimizer by parameter name, as optim_map expects.
>>> optim_dict = {
>>>     name: config.instantiate(cfg_optimizer, [p])
>>>     for name, p in self._model.named_parameters()
>>> }
>>>
>>> # Save checkpoint
>>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict()
>>> torch.save(ckpt, "/tmp/optim_ckpt")
>>>
>>> # Load checkpoint
>>> placeholder_optim_dict = {
>>>     name: config.instantiate(cfg_optimizer, [p])
>>>     for name, p in self._model.named_parameters()
>>> }
>>>
>>> wrapper = OptimizerInBackwardWrapper(placeholder_optim_dict)
>>>
>>> # load_state_dict expects a dict produced by this class's
>>> # state_dict method.
>>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt"))
>>> # placeholder_optim_dict now has updated optimizer states.
get_optim_key(key: str) → Any[source]

Returns the value of key from an arbitrary optimizer running in backward. This assumes that all optimizers running in backward hold the same value for the key, i.e., that they are initialized with the same hyperparameters.
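The lookup described above can be illustrated with a minimal sketch. This is not torchtune's implementation: `ToyOptimizer`, `read_shared_key`, and `all_agree` are hypothetical names, and `ToyOptimizer` only imitates the `param_groups` attribute that `torch.optim.Optimizer` exposes. The sketch assumes the value is read from the first parameter group of whichever optimizer comes first in the map.

```python
from typing import Any, Dict, List


class ToyOptimizer:
    """Hypothetical stand-in exposing only `param_groups`, like torch.optim.Optimizer."""

    def __init__(self, lr: float, weight_decay: float = 0.0) -> None:
        self.param_groups: List[Dict[str, Any]] = [
            {"lr": lr, "weight_decay": weight_decay}
        ]


def read_shared_key(optim_map: Dict[str, ToyOptimizer], key: str) -> Any:
    """Return `key` from an arbitrary optimizer in the map (here: the first one)."""
    any_optimizer = next(iter(optim_map.values()))
    return any_optimizer.param_groups[0][key]


def all_agree(optim_map: Dict[str, ToyOptimizer], key: str) -> bool:
    """Check the assumption that every optimizer holds the same value for `key`."""
    values = {opt.param_groups[0][key] for opt in optim_map.values()}
    return len(values) == 1


# One optimizer per parameter name, all built with the same hyperparameters.
optim_map = {name: ToyOptimizer(lr=2e-5) for name in ("layer.weight", "layer.bias")}
shared_lr = read_shared_key(optim_map, "lr")

# If the optimizers are configured differently, the result depends on which
# optimizer happens to be picked, so the lookup is no longer meaningful.
mixed_map = {"a": ToyOptimizer(lr=1e-3), "b": ToyOptimizer(lr=1e-4)}
```

A check like `all_agree` is one way to confirm the shared-hyperparameter assumption before relying on the returned value, for example when logging the learning rate.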

load_state_dict(optim_ckpt_map: Dict[str, Any])[source]

Load optimizer states from a state dict produced by this class’s state_dict method.

Parameters:

optim_ckpt_map (Dict[str, Any]) – state dict mapping parameter names to optimizer states.

Raises:

RuntimeError – If the optimizer state dict does not contain all the expected parameters.
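The coverage check behind this error can be sketched as follows. This is an illustration under stated assumptions, not torchtune's code: `ToyOptimizer` and `load_all` are hypothetical names, `ToyOptimizer` imitates only the `state_dict` / `load_state_dict` pair of `torch.optim.Optimizer`, and the exact conditions and error message used by the library may differ.

```python
from typing import Any, Dict


class ToyOptimizer:
    """Hypothetical stand-in with a state_dict / load_state_dict pair."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {"step": 0}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.state)

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)


def load_all(optim_map: Dict[str, ToyOptimizer], optim_ckpt_map: Dict[str, Any]) -> None:
    """Load per-parameter states; fail if the checkpoint misses any parameter."""
    missing = set(optim_map) - set(optim_ckpt_map)
    if missing:
        # Validate before loading so a bad checkpoint leaves every optimizer untouched.
        raise RuntimeError(f"Checkpoint has no optimizer state for: {sorted(missing)}")
    for name, optimizer in optim_map.items():
        optimizer.load_state_dict(optim_ckpt_map[name])


optim_map = {"layer.weight": ToyOptimizer(), "layer.bias": ToyOptimizer()}

# A checkpoint covering every parameter name loads cleanly.
load_all(optim_map, {"layer.weight": {"step": 10}, "layer.bias": {"step": 10}})

# A checkpoint that lacks an entry for "layer.bias" is rejected.
try:
    load_all(optim_map, {"layer.weight": {"step": 11}})
    error_message = None
except RuntimeError as err:
    error_message = str(err)
```

In this sketch the check runs before any state is loaded, so a rejected checkpoint does not leave the optimizers partially updated.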

state_dict() → Dict[str, Any][source]

Returns a state dict mapping parameter names to optimizer states. This state dict can only be loaded by this same class.

Returns:

state dict mapping parameter names to optimizer states.

Return type:

Dict[str, Any]
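To make the shape of the returned mapping concrete, here is a small sketch. It assumes each entry is simply the corresponding optimizer's own state dict; `ToyOptimizer` and `collect_state` are hypothetical names, and `ToyOptimizer` only mimics the `state` and `param_groups` keys found in a `torch.optim.Optimizer` state dict.

```python
from typing import Any, Dict


class ToyOptimizer:
    """Hypothetical stand-in whose state_dict has `state` and `param_groups` keys."""

    def __init__(self, lr: float) -> None:
        self.lr = lr
        self.step_count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {
            "state": {0: {"step": self.step_count}},
            "param_groups": [{"lr": self.lr, "params": [0]}],
        }


def collect_state(optim_map: Dict[str, ToyOptimizer]) -> Dict[str, Any]:
    """Build one checkpoint entry per parameter name."""
    return {name: optimizer.state_dict() for name, optimizer in optim_map.items()}


optim_map = {name: ToyOptimizer(lr=2e-5) for name in ("layer.weight", "layer.bias")}
ckpt = collect_state(optim_map)

# The top-level keys are parameter names; each value is a per-optimizer state dict.
param_names = sorted(ckpt)
```

Because the top level is keyed by parameter name instead of following the layout of a single optimizer's state dict, a checkpoint of this shape is meant to be read back through the same wrapper.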
