# Source code for torch.optim.rmsprop
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
# Names exported by a star-import of this module: the optimizer class and its
# functional counterpart, both defined below.
__all__ = ["RMSprop", "rmsprop"]
[docs]class RMSprop(Optimizer): # noqa: D101
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0,
momentum: float = 0,
centered=False,
capturable=False,
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
): # noqa: D107
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= momentum:
raise ValueError(f"Invalid momentum value: {momentum}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= alpha:
raise ValueError(f"Invalid alpha value: {alpha}")
defaults = dict(
lr=lr,
momentum=momentum,
alpha=alpha,
eps=eps,
centered=centered,
weight_decay=weight_decay,
capturable=capturable,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
)
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if group["capturable"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
square_avgs,
momentum_buffer_list,
grad_avgs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("RMSprop does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = (
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"]
else torch.zeros((), dtype=_get_scalar_dtype())
)
state["square_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if group["momentum"] > 0:
state["momentum_buffer"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if group["centered"]:
state["grad_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
square_avgs.append(state["square_avg"])
state_steps.append(state["step"])
if group["momentum"] > 0:
momentum_buffer_list.append(state["momentum_buffer"])
if group["centered"]:
grad_avgs.append(state["grad_avg"])
return has_complex
[docs] @_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
square_avgs: List[Tensor] = []
grad_avgs: List[Tensor] = []
momentum_buffer_list: List[Tensor] = []
state_steps: List[Tensor] = []
has_complex = self._init_group(
group,
params_with_grad,
grads,
square_avgs,
momentum_buffer_list,
grad_avgs,
state_steps,
)
rmsprop(
params_with_grad,
grads,
square_avgs,
grad_avgs,
momentum_buffer_list,
state_steps,
lr=group["lr"],
alpha=group["alpha"],
eps=group["eps"],
weight_decay=group["weight_decay"],
momentum=group["momentum"],
centered=group["centered"],
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
capturable=group["capturable"],
has_complex=has_complex,
)
return loss
# The class docstring is assembled here rather than written inline so that the
# shared option descriptions imported from ``.optimizer`` can be interpolated.
# The first part is a plain raw string (it contains LaTeX braces); only the
# ``Args`` part is an f-string.
RMSprop.__doc__ = (
    r"""Implements RMSprop algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
            &\hspace{13mm}   \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},
                \: centered, \: \epsilon \text{ (epsilon)}                                       \\
            &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
                \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}v_t           \leftarrow   \alpha v_{t-1} + (1 - \alpha) g^2_t
                \hspace{8mm}                                                                     \\
            &\hspace{5mm} \tilde{v_t} \leftarrow v_t                                             \\
            &\hspace{5mm}if \: centered                                                          \\
            &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t            \\
            &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} -  \big(g^{ave}_{t} \big)^2        \\
            &\hspace{5mm}if \: \mu > 0                                                           \\
            &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
                g_t/ \big(\sqrt{\tilde{v_t}} +  \epsilon \big)                                   \\
            &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t                \\
            &\hspace{5mm} else                                                                   \\
            &\hspace{10mm}\theta_t      \leftarrow   \theta_{t-1} -
                \gamma  g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big)  \hspace{3mm}              \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to
    `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton,
    and for the centered version to `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
    The implementation here takes the square root of the squared gradient average before
    adding epsilon (note that TensorFlow interchanges these two operations). The effective
    learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        momentum (float, optional): momentum factor (default: 0)
        centered (bool, optional): if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
            (default: False)
        {_capturable_doc}
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}

    """
)
def _single_tensor_rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
*,
lr: float,
alpha: float,
eps: float,
weight_decay: float,
momentum: float,
centered: bool,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
for i, param in enumerate(params):
step = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
grad = grads[i]
grad = grad if not maximize else -grad
square_avg = square_avgs[i]
step += 1
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
is_complex_param = torch.is_complex(param)
if is_complex_param:
param = torch.view_as_real(param)
grad = torch.view_as_real(grad)
square_avg = torch.view_as_real(square_avg)
square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
if centered:
grad_avg = grad_avgs[i]
if is_complex_param:
grad_avg = torch.view_as_real(grad_avg)
grad_avg.lerp_(grad, 1 - alpha)
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
else:
avg = square_avg.sqrt()
if differentiable:
avg = avg.add(eps)
else:
avg = avg.add_(eps)
if momentum > 0:
buf = momentum_buffer_list[i]
if is_complex_param:
buf = torch.view_as_real(buf)
buf.mul_(momentum).addcdiv_(grad, avg)
param.add_(buf, alpha=-lr)
else:
param.addcdiv_(grad, avg, value=-lr)
def _multi_tensor_rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
*,
lr: float,
alpha: float,
eps: float,
weight_decay: float,
momentum: float,
centered: bool,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
if len(params) == 0:
return
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps] # type: ignore[list-item]
)
for (
(
grouped_params_,
grouped_grads_,
grouped_square_avgs_,
grouped_grad_avgs_,
grouped_momentum_buffer_list_,
grouped_state_steps_,
)
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
if has_complex:
state_and_grads = [grouped_grads, grouped_square_avgs]
if momentum > 0:
grouped_momentum_buffer_list = cast(
List[Tensor], grouped_momentum_buffer_list_
)
state_and_grads.append(grouped_momentum_buffer_list)
if centered:
grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
state_and_grads.append(grouped_grad_avgs)
_view_as_real(grouped_params, *state_and_grads)
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(grouped_state_steps, 1)
if weight_decay != 0:
# Re-use the intermediate memory (grouped_grads) already allocated for maximize
if maximize:
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
else:
grouped_grads = torch._foreach_add( # type: ignore[assignment]
grouped_grads, grouped_params, alpha=weight_decay
)
torch._foreach_mul_(grouped_square_avgs, alpha)
torch._foreach_addcmul_(
grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
)
if centered:
grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
avg = torch._foreach_addcmul(
grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
)
torch._foreach_sqrt_(avg)
torch._foreach_add_(avg, eps)
else:
avg = torch._foreach_sqrt(grouped_square_avgs)
torch._foreach_add_(avg, eps)
if momentum > 0:
grouped_momentum_buffer_list = cast(
List[Tensor], grouped_momentum_buffer_list_
)
torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
# If LR is a tensor, the else branch will internally call item()
# which will cause silent incorrectness if we are capturing
if capturable and isinstance(lr, torch.Tensor):
momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
torch._foreach_add_(grouped_params, momentum_lr)
else:
torch._foreach_add_(
grouped_params, grouped_momentum_buffer_list, alpha=-lr
)
else:
# If LR is a tensor, the else branch will internally call item()
# which will cause silent incorrectness if we are capturing
if capturable and isinstance(lr, torch.Tensor):
torch._foreach_div_(avg, -lr)
torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
else:
torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.

    Dispatches to the multi-tensor (foreach) implementation or to the
    single-tensor implementation and forwards all arguments unchanged. When
    ``foreach`` is ``None`` the choice is delegated to
    ``_default_to_fused_or_foreach``.

    See :class:`~torch.optim.RMSprop` for details.

    Raises:
        RuntimeError: if ``state_steps`` contains non-tensor entries (checked
            only when not compiling), or if ``foreach`` is requested while
            running under ``torch.jit.script``.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The scripting check is repeated on purpose: it is kept exactly as the
    # branch condition rather than relying on the raise above.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )