RMSprop

class torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False, foreach=None, maximize=False, differentiable=False)

Implements RMSprop algorithm.

\begin{aligned}
&\textbf{input}: \alpha \text{ (alpha)},\: \gamma \text{ (lr)},\: \theta_0 \text{ (params)},\: f(\theta) \text{ (objective)}, \\
&\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: \textit{centered} \\
&\textbf{initialize}: v_0 \leftarrow 0 \text{ (square average)},\: \mathbf{b}_0 \leftarrow 0 \text{ (buffer)},\: g^{ave}_0 \leftarrow 0 \\
&\textbf{for} \: t = 1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm} g_t \leftarrow \nabla_{\theta} f_t(\theta_{t-1}) \\
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm} v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g_t^2 \\
&\hspace{5mm} \tilde{v}_t \leftarrow v_t \\
&\hspace{5mm} \textbf{if} \: \textit{centered} \\
&\hspace{10mm} g^{ave}_t \leftarrow \alpha g^{ave}_{t-1} + (1 - \alpha) g_t \\
&\hspace{10mm} \tilde{v}_t \leftarrow \tilde{v}_t - \big(g^{ave}_t\big)^2 \\
&\hspace{5mm} \textbf{if} \: \mu > 0 \\
&\hspace{10mm} \mathbf{b}_t \leftarrow \mu \mathbf{b}_{t-1} + g_t / \big(\sqrt{\tilde{v}_t} + \epsilon\big) \\
&\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \mathbf{b}_t \\
&\hspace{5mm} \textbf{else} \\
&\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma g_t / \big(\sqrt{\tilde{v}_t} + \epsilon\big) \\
&\textbf{return} \: \theta_t
\end{aligned}

For further details regarding the algorithm, see the lecture notes by G. Hinton; for the centered version, see Generating Sequences With Recurrent Neural Networks. The implementation here takes the square root of the squared-gradient average before adding epsilon (note that TensorFlow interchanges these two operations). The effective learning rate is thus γ/(√v + ϵ), where γ is the scheduled learning rate and v is the weighted moving average of the squared gradient.
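The update rule can be traced by hand with a small scalar version. The following is a minimal sketch in plain Python, not the library's implementation: the function name `rmsprop_step` and the `state` dict layout are illustrative assumptions, and the objective f(θ) = θ² is chosen only to make the arithmetic easy to check.

```python
import math

def rmsprop_step(theta, grad, state, lr=0.01, alpha=0.99, eps=1e-8,
                 weight_decay=0.0, momentum=0.0, centered=False):
    """One scalar RMSprop update following the pseudocode above.

    `state` is a hypothetical dict with keys "v" (square average),
    "g_ave" (gradient average) and "b" (momentum buffer), all starting at 0.0.
    """
    if weight_decay != 0:
        grad = grad + weight_decay * theta
    state["v"] = alpha * state["v"] + (1 - alpha) * grad * grad
    v_tilde = state["v"]
    if centered:
        state["g_ave"] = alpha * state["g_ave"] + (1 - alpha) * grad
        v_tilde = v_tilde - state["g_ave"] ** 2
    # Square root first, then epsilon, as described in the text.
    denom = math.sqrt(v_tilde) + eps
    if momentum > 0:
        state["b"] = momentum * state["b"] + grad / denom
        return theta - lr * state["b"]
    return theta - lr * grad / denom

state = {"v": 0.0, "g_ave": 0.0, "b": 0.0}
theta = 1.0
# Gradient of f(theta) = theta**2 at theta = 1.0 is 2.0.
theta = rmsprop_step(theta, 2.0 * theta, state)
print(theta)  # approximately 0.9
```

With v starting at zero, the first step has v = 0.01 · 4 = 0.04, so the denominator is about 0.2 and the parameter moves by roughly 0.01 · 2 / 0.2 = 0.1, ten times the nominal learning rate.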

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float, optional) – learning rate (default: 1e-2)

  • momentum (float, optional) – momentum factor (default: 0)

  • alpha (float, optional) – smoothing constant (default: 0.99)

  • eps (float, optional) – term added to the denominator to improve numerical stability (default: 1e-8)

  • centered (bool, optional) – if True, compute the centered RMSProp, in which the gradient is normalized by an estimate of its variance (default: False)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • foreach (bool, optional) – whether foreach implementation of optimizer is used. If unspecified by the user (so foreach is None), we will try to use foreach over the for-loop implementation on CUDA, since it is usually significantly more performant. (default: None)

  • maximize (bool, optional) – maximize the params based on the objective, instead of minimizing (default: False)

  • differentiable (bool, optional) – whether autograd should occur through the optimizer step in training. Otherwise, the step() function runs in a torch.no_grad() context. Setting to True can impair performance, so leave it False if you don’t intend to run autograd through this instance (default: False)
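As a usage illustration, the constructor arguments above can be passed as in the following sketch. The model, the synthetic data, and the number of iterations are assumptions made for the example, not recommendations.

```python
import torch

torch.manual_seed(0)

# Synthetic linear-regression data (illustrative only).
x = torch.randn(64, 3)
true_w = torch.tensor([[1.0], [-2.0], [0.5]])
y = x @ true_w

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-8)
loss_fn = torch.nn.MSELoss()

initial_loss = loss_fn(model(x), y).item()
for _ in range(200):
    optimizer.zero_grad()        # clear gradients from the previous iteration
    loss = loss_fn(model(x), y)
    loss.backward()              # populate .grad on each parameter
    optimizer.step()             # apply the RMSprop update
final_loss = loss_fn(model(x), y).item()

print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```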

add_param_group(param_group)

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
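The fine-tuning scenario described above might look like the following sketch. The two-layer split into `backbone` and `head` and the chosen learning rates are hypothetical; options omitted from the new group fall back to the optimizer's defaults.

```python
import torch

backbone = torch.nn.Linear(4, 4)   # stands in for a pre-trained, frozen part
head = torch.nn.Linear(4, 2)       # new layer trained from the start

for p in backbone.parameters():
    p.requires_grad_(False)

# Initially only the head is optimized.
optimizer = torch.optim.RMSprop(head.parameters(), lr=1e-2)

# Later in training: unfreeze the backbone and register it with its own lr.
for p in backbone.parameters():
    p.requires_grad_(True)
optimizer.add_param_group({"params": list(backbone.parameters()), "lr": 1e-3})

print(len(optimizer.param_groups))
```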

load_state_dict(state_dict)

Loads the optimizer state.

Parameters:

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

register_step_post_hook(hook)

Register an optimizer step post hook which will be called after optimizer step. It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemoveableHandle
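A minimal sketch of a post hook, assuming a single toy parameter: the hook name `count_steps` and the counter dict are illustrative. It counts completed steps and is removed again through the returned handle.

```python
import torch

param = torch.nn.Parameter(torch.tensor([1.0, -1.0]))
optimizer = torch.optim.RMSprop([param], lr=0.01)

step_count = {"n": 0}

def count_steps(optimizer, args, kwargs):
    # Called after every optimizer.step(); the return value is ignored.
    step_count["n"] += 1

handle = optimizer.register_step_post_hook(count_steps)

param.pow(2).sum().backward()
optimizer.step()      # hook runs
handle.remove()
optimizer.step()      # hook has been removed and does not run

print(step_count["n"])
```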

register_step_pre_hook(hook)

Register an optimizer step pre hook which will be called before optimizer step. It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemoveableHandle
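A pre hook can, for example, inspect gradients just before they are consumed. The sketch below records the global gradient norm; the hook name `record_grad_norm` and the toy parameter are assumptions for illustration. It returns None, so args and kwargs are passed to step() unchanged.

```python
import torch

param = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
optimizer = torch.optim.RMSprop([param], lr=0.01)

grad_norms = []

def record_grad_norm(optimizer, args, kwargs):
    # Accumulate the squared L2 norm over every parameter that has a gradient.
    total = 0.0
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                total += p.grad.norm().item() ** 2
    grad_norms.append(total ** 0.5)
    # Returning None leaves args and kwargs unmodified.

handle = optimizer.register_step_pre_hook(record_grad_norm)

param.pow(2).sum().backward()   # gradient is 2 * param = [6, 8]
optimizer.step()
handle.remove()

print(grad_norms)
```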

state_dict()

Returns the state of the optimizer as a dict.

It contains two entries:

  • state - a dict holding the current optimization state. Its content differs between optimizer classes.

  • param_groups - a list containing all parameter groups, where each parameter group is a dict
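A round trip through state_dict() and load_state_dict() can be sketched as follows. The model and hyperparameters are illustrative; in practice the dict would typically be written to disk with torch.save between the two calls.

```python
import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.05, momentum=0.9)

# Take one step so the optimizer has per-parameter state to save.
model(torch.ones(1, 2)).sum().backward()
optimizer.step()

saved = optimizer.state_dict()
print(sorted(saved.keys()))  # ['param_groups', 'state']

# A fresh optimizer over the same parameters picks up the saved state,
# including the hyperparameters stored in param_groups.
restored = torch.optim.RMSprop(model.parameters(), lr=0.01, momentum=0.9)
restored.load_state_dict(saved)
print(restored.param_groups[0]["lr"])
```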

zero_grad(set_to_none=True)

Resets the gradients of all optimized torch.Tensors: to None by default, or to zero if set_to_none=False.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have a lower memory footprint and can modestly improve performance. However, it changes certain behaviors. For example:

  1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently.

  2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient.

  3. torch.optim optimizers behave differently depending on whether the gradient is 0 or None (in the first case the step is performed with a gradient of 0; in the second the step is skipped altogether).
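The difference between the two modes can be observed directly. This is a small sketch with a single toy parameter; the variable names are illustrative.

```python
import torch

param = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
optimizer = torch.optim.RMSprop([param], lr=0.01)

param.sum().backward()
optimizer.zero_grad()                      # default: set_to_none=True
grad_after_default = param.grad            # the attribute is now None

param.sum().backward()
optimizer.zero_grad(set_to_none=False)     # keep the tensor, fill it with zeros
grad_after_zeroing = param.grad.clone()

print(grad_after_default, grad_after_zeroing)
```

Code that reads `.grad` after zero_grad() therefore needs to handle None unless set_to_none=False is passed explicitly.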
