Optimizers¶
This module implements an optimizer wrapper that works with the Manager to provide fault tolerance.
- class torchft.optim.OptimizerWrapper(manager: Manager, optim: Optimizer)[source]¶
Bases: Optimizer
This wraps any provided optimizer and in conjunction with the manager will provide fault tolerance.
zero_grad() must be called at the start of the forwards pass and step() must be called at the end of the backwards pass.
Depending on the state of the manager, the optimizer will either commit the gradients to the wrapped optimizer or ignore them.
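A minimal training-loop sketch showing the required call order (zero_grad() before the forward pass, step() after the backward pass). The manager, model, and dataloader arguments, the learning rate, and the loss function are illustrative assumptions, not part of the torchft API:

    import torch
    from torchft.optim import OptimizerWrapper

    def train_epoch(manager, model, dataloader):
        # Wrap a regular torch optimizer; the wrapper consults the manager on each step.
        optim = OptimizerWrapper(manager, torch.optim.SGD(model.parameters(), lr=0.01))

        for inputs, targets in dataloader:
            optim.zero_grad()   # start of the forwards pass
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
            loss.backward()
            optim.step()        # end of the backwards pass: the gradients are either
                                # committed to the wrapped optimizer or ignored,
                                # depending on the state of the manager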
- add_param_group(param_group: object) → None [source]¶
Add a param group to the Optimizer's param_groups. This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters:
param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
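A sketch of the fine-tuning pattern described above. The model.backbone and model.head attributes are hypothetical module names, and model and manager are assumed to be defined elsewhere:

    import torch
    from torchft.optim import OptimizerWrapper

    # Start by optimizing only the (hypothetical) head; the backbone stays frozen.
    for p in model.backbone.parameters():
        p.requires_grad = False
    optim = OptimizerWrapper(manager, torch.optim.SGD(model.head.parameters(), lr=0.01))

    # Later in training: unfreeze the backbone and add it as a new param group
    # with its own learning rate.
    for p in model.backbone.parameters():
        p.requires_grad = True
    optim.add_param_group({"params": list(model.backbone.parameters()), "lr": 0.001})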
- load_state_dict(state_dict: object) → None [source]¶
Load the optimizer state.
- Parameters:
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
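A self-contained round-trip sketch, using a plain SGD optimizer for brevity since the wrapper exposes the same state_dict()/load_state_dict() interface:

    import torch
    from torch import nn

    model = nn.Linear(2, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    model(torch.randn(1, 2)).sum().backward()
    optim.step()                  # populates the per-parameter momentum buffers

    saved = optim.state_dict()    # a plain dict, suitable for torch.save()

    fresh = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    fresh.load_state_dict(saved)  # restores momentum buffers and hyperparameters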
- state_dict() → object [source]¶
Return the state of the optimizer as a dict. It contains two entries:
- state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.
- param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group.
NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.
A returned state dict might look something like:
    {
        'state': {
            0: {'momentum_buffer': tensor(...), ...},
            1: {'momentum_buffer': tensor(...), ...},
            2: {'momentum_buffer': tensor(...), ...},
            3: {'momentum_buffer': tensor(...), ...}
        },
        'param_groups': [
            {
                'lr': 0.01,
                'weight_decay': 0,
                ...
                'params': [0]
            },
            {
                'lr': 0.001,
                'weight_decay': 0.5,
                ...
                'params': [1, 2, 3]
            }
        ]
    }
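Continuing the round-trip sketch under load_state_dict() above, the two entries can be inspected directly:

    sd = optim.state_dict()
    print(list(sd.keys()))              # ['state', 'param_groups']
    print(sd["param_groups"][0]["lr"])  # 0.01
    print(sorted(sd["state"].keys()))   # integer parameter ids, e.g. [0, 1]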
- step(closure: Optional[object] = None) → None [source]¶
Perform a single optimization step to update parameters.
- Parameters:
closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note
Unless otherwise specified, this function should not modify the .grad field of the parameters.
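A standalone sketch of the closure pattern using torch.optim.LBFGS, which requires one; it is assumed that the same call shape applies when the optimizer is wrapped, since step() accepts an optional closure:

    import torch
    from torch import nn

    model = nn.Linear(4, 1)
    optim = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    def closure():
        # Re-evaluates the model and returns the loss; LBFGS may call this several times.
        optim.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    optim.step(closure)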
- zero_grad(set_to_none: bool = True) → None [source]¶
Reset the gradients of all optimized torch.Tensors.
- Parameters:
set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
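A standalone illustration of the difference between the two modes, using a single parameter and a plain SGD optimizer for brevity:

    import torch
    from torch import nn

    p = nn.Parameter(torch.randn(3))
    opt = torch.optim.SGD([p], lr=0.1)

    p.sum().backward()
    opt.zero_grad(set_to_none=False)
    print(p.grad)   # tensor([0., 0., 0.]): a zero tensor is kept in memory

    p.sum().backward()
    opt.zero_grad(set_to_none=True)   # the default
    print(p.grad)   # None: the gradient tensor is freed entirely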