
Adafactor

class torch.optim.Adafactor(params, lr=0.01, beta2_decay=-0.8, eps=(None, 0.001), d=1.0, weight_decay=0.0, *, foreach=None, maximize=False)

Implements the Adafactor algorithm.

\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \tau \text{ (}\beta_2\text{ decay)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}, \\
&\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{ (clipping threshold)}, \\
&\hspace{15mm} \: \lambda \text{ (weight decay)}, \: \textit{maximize} \\
&\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
&\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
&\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
&\hspace{5mm}\rho_t \leftarrow \min(lr, \frac{1}{\sqrt{t}}) \\
&\hspace{5mm}\alpha_t \leftarrow \max(\epsilon_2, \text{RMS}(\theta_{t-1}))\rho_t \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
&\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
&\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
&\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
&\hspace{10mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{\max(1^\top_n \cdot R_t, \epsilon_1)} \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+ (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
&\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{\max(1, \frac{\text{RMS}(U_t)}{d})} \\
&\hspace{5mm}\theta_t \leftarrow \theta_t - \alpha_t \widehat{U}_t \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\textbf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
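The pseudocode can be traced with a small NumPy sketch. It is not the PyTorch implementation: it assumes NumPy, handles only 1-D and 2-D parameters, follows the notation above literally (row and column sums), and uses eps1 = 1e-30 (the paper's value) because the sketch has no dtype-dependent default. The helper name adafactor_step and its state keys are hypothetical.

```python
import numpy as np


def rms(x):
    """Root mean square: sqrt(mean(x ** 2))."""
    return float(np.sqrt(np.mean(np.square(x))))


def adafactor_step(theta, grad, state, lr=1e-2, beta2_decay=-0.8,
                   eps=(1e-30, 1e-3), d=1.0, weight_decay=0.0, maximize=False):
    """One iteration of the pseudocode for a single 1-D or 2-D parameter.

    `state` is a plain dict carried between calls; its keys are hypothetical.
    """
    eps1, eps2 = eps
    g = -grad if maximize else grad
    t = state.get("step", 0) + 1
    state["step"] = t

    beta2_t = 1.0 - t ** beta2_decay            # beta2-hat_t = 1 - t^tau
    rho_t = min(lr, 1.0 / np.sqrt(t))           # relative step size
    alpha_t = max(eps2, rms(theta)) * rho_t     # uses RMS(theta_{t-1})
    theta = theta - lr * weight_decay * theta   # decoupled weight decay

    g2 = g * g
    if g.ndim > 1:
        n, m = g.shape                          # 2-D only in this sketch
        R = beta2_t * state.get("row", np.zeros(n)) + (1 - beta2_t) * g2.sum(axis=1)
        C = beta2_t * state.get("col", np.zeros(m)) + (1 - beta2_t) * g2.sum(axis=0)
        state["row"], state["col"] = R, C
        V = np.outer(R, C) / max(R.sum(), eps1)  # R_t C_t / max(1_n^T R_t, eps1)
    else:
        V = beta2_t * state.get("v", np.zeros_like(g)) + (1 - beta2_t) * g2
        state["v"] = V

    U = g / np.maximum(np.sqrt(V), eps1)
    U_hat = U / max(1.0, rms(U) / d)            # clip so that RMS(U_hat) <= d
    return theta - alpha_t * U_hat


example = adafactor_step(np.ones(3), np.array([1.0, -2.0, 0.5]), {})
print(example)
```

For a 1-D parameter of ones with gradient [1, -2, 0.5], the first step gives $\widehat{U}_1 = [1, -1, 1]$ and $\alpha_1 = 0.01$, so the parameter becomes [0.99, 1.01, 0.99].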

For further details regarding the algorithm, we refer to Adafactor: Adaptive Learning Rates with Sublinear Memory Cost (Shazeer and Stern, 2018).

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, Tensor, optional) – unlike other optimizers, Adafactor does not require a learning rate, and Shazeer and Stern do not use lr at all. Deviating from the paper, this implementation uses lr for applying weight decay and as the maximum value for the relative step size rho_t. The paper uses a constant of 0.01 as that maximum, so 0.01 is the default value. (default: 1e-2)

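  • beta2_decay (float, optional) – the decay rate of beta2. beta2 conventionally refers to the coefficient used for computing the running average of the squared gradient; at step t this implementation uses beta2 = 1 - t^beta2_decay, so a more negative value makes beta2 approach 1 faster. (default: -0.8)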

  • eps (Tuple[float, float], optional) – epsilon1 is the term added to the denominator of the update calculation to improve numerical stability. This use of epsilon1 deviates from the algorithm written in the paper! See note below for more details. epsilon2 is the term used to avoid having too small a weight update when applying parameter scaling. (default: (None, 1e-3))

  • d (float, optional) – the clipping threshold, used to avoid larger-than-desired updates: an update whose RMS exceeds d is scaled down to have RMS d. (default: 1.0)

  • weight_decay (float, optional) – weight decay coefficient (default: 0.0)

  • foreach (bool, optional) – whether the foreach implementation of the optimizer is used. Note that the foreach implementation uses ~sizeof(params) more peak memory than the for-loop version, because its intermediates are a tensorlist rather than a single tensor. As Adafactor is commonly used when memory is prohibitive, it defaults to the slower single-tensor for-loop implementation unless this flag is explicitly set to True. This differs from other optimizers, which attempt to default to foreach on CUDA for faster runtime. (default: None)

  • maximize (bool, optional) – maximize the objective with respect to the params, instead of minimizing (default: False)

Note

The implementation of Adafactor subtly differs from Shazeer and Stern, and from implementations in some other frameworks, in its use of the learning rate and $\epsilon_1$.

Regarding the learning rate hyperparameter: Shazeer and Stern do not use lr at all, as the stated algorithm uses $\rho_t$ and update clipping to control the step size.

This implementation allows lr to set the maximum value of $\rho_t$:

\begin{aligned} &\hspace{5mm}\rho_t \leftarrow \min(lr, \frac{1}{\sqrt{t}}) \end{aligned}

This differs from Shazeer and Stern, who use a constant of 0.01 as the maximum value of $\rho_t$:

\begin{aligned} &\hspace{5mm}\rho_t \leftarrow \min(0.01, \frac{1}{\sqrt{t}}) \end{aligned}
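The two caps can be compared with a few lines of plain Python. The function names rho and rho_paper are hypothetical; they only evaluate the two formulas above.

```python
import math


def rho(t, lr=1e-2):
    """Relative step size in this implementation: min(lr, 1 / sqrt(t))."""
    return min(lr, 1.0 / math.sqrt(t))


def rho_paper(t):
    """Relative step size in the paper: a fixed cap of 0.01."""
    return min(0.01, 1.0 / math.sqrt(t))


# 1/sqrt(t) only falls below 0.01 once t exceeds 10_000, so with the default
# lr=1e-2 the two schedules coincide; a smaller lr lowers the cap.
for t in (1, 100, 10_000, 40_000):
    print(t, rho(t), rho_paper(t), rho(t, lr=1e-3))
```

With the default lr=1e-2 the schedules coincide at every step, and both reach 0.005 at t = 40,000; with lr=1e-3, $\rho_t$ stays at 0.001 for all of the steps printed.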

Shazeer and Stern do not prescribe how weight decay should be computed, so this implementation uses the learning rate as the coefficient for decoupled weight decay, similar to what is suggested in Decoupled Weight Decay Regularization.
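In the algorithm above this decay is the step $\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}$, which uses lr directly rather than $\rho_t$ or $\alpha_t$. A plain-Python sketch of that step in isolation (the function name is hypothetical, and the adaptive update is deliberately omitted):

```python
def decoupled_weight_decay(theta, lr=1e-2, weight_decay=0.0):
    """theta <- theta - lr * weight_decay * theta, kept separate from the adaptive update."""
    return [p - lr * weight_decay * p for p in theta]


theta = [1.0, -2.0]
for _ in range(1000):
    theta = decoupled_weight_decay(theta, lr=1e-2, weight_decay=0.1)
print(theta)  # each step scales by (1 - 1e-3); after 1000 steps roughly 0.368 * [1.0, -2.0]
```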

Regarding the use of $\epsilon_1$: the implementation attempts to replicate the presumed intention of Shazeer and Stern to use $\epsilon_1$ as a stabilizing term when the squared gradient becomes small.

This stabilization can be written as

\begin{aligned}
&\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
&\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
&\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{\max(1^\top_n \cdot R_t, \epsilon_1)} \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
\end{aligned}

where the row and column factors of the squared gradient, $R_t$ and $C_t$, are left alone, and $\epsilon_1$ is applied only in the final calculation of the variance estimate $\widehat{V}_t$ and of the update $U_t$.

This is in contrast to Shazeer and Stern and other frameworks, which apply $\epsilon_1$ to both the row and column factors of the squared gradient but not in the calculations that follow:

\begin{aligned}
&\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+ (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
&\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+ (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
&\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\
\end{aligned}
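The difference in where $\epsilon_1$ enters can be seen numerically on the first step, where $\widehat{\beta}_{2_1} = 1 - 1^{\tau} = 0$, so $R_1$ and $C_1$ are simply the row and column sums of $G_1 \odot G_1$. This NumPy sketch assumes a 2-D gradient and omits weight decay, $\alpha_t$, and update clipping; the function names and the deliberately large $\epsilon_1 = 10^{-3}$ in the second case are hypothetical choices made to expose the difference.

```python
import numpy as np


def first_step_update_this_impl(G, eps1):
    """U_1 with eps1 applied after factoring (beta2-hat_1 = 0, so R, C hold G only)."""
    G2 = G * G
    R = G2.sum(axis=1)                        # (G ⊙ G) · 1_m, one entry per row
    C = G2.sum(axis=0)                        # 1_n^T · (G ⊙ G), one entry per column
    V = np.outer(R, C) / max(R.sum(), eps1)   # eps1 guards the normalizer
    return G / np.maximum(np.sqrt(V), eps1)   # eps1 guards the update denominator


def first_step_update_paper(G, eps1):
    """U_1 with eps1 added to the squared gradient before factoring."""
    G2 = G * G + eps1                         # G ⊙ G + eps1 · 1_n 1_m^T
    R = G2.sum(axis=1)
    C = G2.sum(axis=0)
    V = np.outer(R, C) / R.sum()
    return G / np.sqrt(V)


# A gradient whose first row is entirely zero: both variants stay finite.
G = np.array([[0.0, 0.0], [3.0, 4.0]])
print(first_step_update_this_impl(G, 1e-30))
print(first_step_update_paper(G, 1e-30))

# A tiny gradient with a deliberately large eps1: the placements now disagree.
g_small = np.full((2, 2), 1e-3)
print(first_step_update_this_impl(g_small, 1e-3))   # every entry 1.0
print(first_step_update_paper(g_small, 1e-3))       # every entry ~0.0316
```

Both variants agree when $\epsilon_1$ is negligible relative to the squared gradient; when it is not, placing $\epsilon_1$ on the final denominators (this implementation) yields an update of 1.0 per entry here, while adding it to $G_t \odot G_t$ (the paper) yields about 0.0316.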
add_param_group(param_group)

Add a param group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

load_state_dict(state_dict)

Load the optimizer state.

Parameters

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

register_load_state_dict_post_hook(hook, prepend=False)

Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used.

The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook, prepend=False)

Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict in place or optionally return a new one. If a state_dict is returned, it will be loaded into the optimizer instead.

The hook will be called with argument self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_state_dict_post_hook(hook, prepend=False)

Register a state dict post-hook which will be called after state_dict() is called.

It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict in place or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_state_dict_pre_hook(hook, prepend=False)

Register a state dict pre-hook which will be called before state_dict() is called.

It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.

Parameters
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_step_post_hook(hook)

Register an optimizer step post hook which will be called after optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters

hook (Callable) – The user defined hook to be registered.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

register_step_pre_hook(hook)

Register an optimizer step pre hook which will be called before optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters

hook (Callable) – The user defined hook to be registered.

Returns

a handle that can be used to remove the added hook by calling handle.remove()

Return type

torch.utils.hooks.RemovableHandle

state_dict()

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group.

NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
        }
    ]
}
Return type

Dict[str, Any]

step(closure=None)

Perform a single optimization step.

Parameters

closure (Callable, optional) – A closure that reevaluates the model and returns the loss.

zero_grad(set_to_none=True)

Reset the gradients of all optimized torch.Tensors.

Parameters

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.
  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.
  3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
(default: True)
