OptimizerInBackwardWrapper

class torchtune.training.OptimizerInBackwardWrapper(optim_map: Dict[str, Optimizer])[source]

A bare-bones class meant for checkpoint save and load for optimizers running in backward. Usage is limited to the example shown below.

Note

This wrapper is only meant to be used for single-device use cases. Distributed use cases such as FSDP, which require specialized optimizer state checkpointing, are not supported.

Parameters:

optim_map (Dict[str, torch.optim.Optimizer]) – Mapping from parameter names to optimizers.

Example

>>> optim_dict = {
>>>     n: config.instantiate(cfg_optimizer, [p])
>>>     for n, p in self._model.named_parameters()
>>> }
>>>
>>> # Save checkpoint
>>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict()
>>> torch.save(ckpt, "/tmp/optim_ckpt")
>>>
>>> # Load checkpoint
>>> placeholder_optim_dict = {
>>>     n: config.instantiate(cfg_optimizer, [p])
>>>     for n, p in self._model.named_parameters()
>>> }
>>>
>>> wrapper = OptimizerInBackwardWrapper(placeholder_optim_dict)
>>>
>>> # load_state_dict expects a dict produced by this class's
>>> # state_dict method.
>>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt"))
>>> # placeholder_optim_dict now has updated optimizer states.
get_last_lr() → float[source]

Gets the last learning rate from the attached LR scheduler.

Returns:

The last learning rate.

Return type:

float

Raises:

RuntimeError – If the LR scheduler has not been set.

get_optim_key(key: str) → Any[source]

Returns the value of key from an arbitrary optimizer running in backward. Note that this assumes all optimizers in backward have the same value for the key, i.e., are initialized with the same hyperparameters.
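
For illustration only (not part of the library docs), a minimal sketch of reading a shared hyperparameter, assuming wrapper is the OptimizerInBackwardWrapper built in the class example above and that every wrapped optimizer was created with the same settings:

>>> # Hypothetical usage: all wrapped optimizers share the same "lr",
>>> # so reading it from any one of them is sufficient.
>>> lr = wrapper.get_optim_key("lr")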

load_state_dict(optim_ckpt_map: Dict[str, Any])[source]

Load optimizer states from a state dict produced by this class’s state_dict method.

Parameters:

optim_ckpt_map (Dict[str, Any]) – state dict mapping parameter names to optimizer states.

Raises:

RuntimeError – If the optimizer state dict does not contain all the expected parameters.

set_lr_scheduler(lr_scheduler: LRScheduler) → None[source]

Sets the learning rate scheduler and modifies its step method to update all optimizers.

Parameters:

lr_scheduler (LRScheduler) – The learning rate scheduler to use.
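
As a hedged sketch (names are illustrative, not from the source), a scheduler can be constructed against any one of the wrapped optimizers and then attached. LinearLR is used here purely as a stand-in, and placeholder_optim_dict and wrapper refer to the class example above:

>>> from torch.optim.lr_scheduler import LinearLR
>>>
>>> # Build the scheduler on one wrapped optimizer; after set_lr_scheduler,
>>> # stepping it propagates the new LR to every wrapped optimizer.
>>> scheduler = LinearLR(next(iter(placeholder_optim_dict.values())), total_iters=100)
>>> wrapper.set_lr_scheduler(scheduler)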

state_dict() → Dict[str, Any][source]

Returns a state dict mapping parameter names to optimizer states. This state dict can only be loaded back by this class's load_state_dict method.

Returns:

state dict mapping parameter names to optimizer states.

Return type:

Dict[str, Any]

step_lr_scheduler(epoch: int = None)[source]

Steps the attached learning rate scheduler.

Parameters:

epoch (int, optional) – The current epoch number. Defaults to None.

Raises:

RuntimeError – If the LR scheduler has not been set.
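
A minimal training-loop sketch, assuming a scheduler has been attached via set_lr_scheduler and that optimizer-in-backward hooks were registered elsewhere so each optimizer steps during loss.backward(); model, batch, and num_steps are hypothetical placeholders:

>>> for step in range(num_steps):
>>>     loss = model(batch).sum()
>>>     # Each per-parameter optimizer steps inside backward via its hook.
>>>     loss.backward()
>>>     # Advance the shared LR schedule and log the current rate.
>>>     wrapper.step_lr_scheduler()
>>>     print(f"step {step}: lr={wrapper.get_last_lr()}")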
