[docs]classAdamW(Adam):def__init__(self,params:ParamsT,lr:Union[float,Tensor]=1e-3,betas:tuple[Union[float,Tensor],Union[float,Tensor]]=(0.9,0.999),eps:float=1e-8,weight_decay:float=1e-2,amsgrad:bool=False,*,maximize:bool=False,foreach:Optional[bool]=None,capturable:bool=False,differentiable:bool=False,fused:Optional[bool]=None,):super().__init__(params,lr,betas,eps,weight_decay,amsgrad,foreach=foreach,maximize=maximize,capturable=capturable,differentiable=differentiable,fused=fused,decoupled_weight_decay=True,)# Preserve decoupled_weight_decay from AdamW for backwards compatibility. The following# guarantees that decoupled_weight_decay will always be True for loading any state into# AdamWdef__setstate__(self,state):super().__setstate__(state)forgroupinself.param_groups:group["decoupled_weight_decay"]=True
AdamW.__doc__=(r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance. .. math:: \begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \: \epsilon \text{ (epsilon)} \\ &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, \: \textit{maximize} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 \text{ ( second moment)}, \: v_0^{max}\leftarrow 0 \\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm}\textbf{else} \\ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ &\hspace{5mm}\textbf{if} \: amsgrad \\ &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\ &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\textbf{else} \\ &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. """+rf""" Args:{_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR is not yet supported for all our implementations. Please use a float LR if you are not also specifying fused=True or capturable=True. 
betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay coefficient (default: 1e-2) amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False){_maximize_doc}{_foreach_doc}{_capturable_doc}{_differentiable_doc}{_fused_doc} .. Note:: A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """)# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adamdefadamw(params:list[Tensor],grads:list[Tensor],exp_avgs:list[Tensor],exp_avg_sqs:list[Tensor],max_exp_avg_sqs:list[Tensor],state_steps:list[Tensor],# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627# setting this as kwarg for now as functional API is compiled by torch/distributed/optimforeach:Optional[bool]=None,capturable:bool=False,differentiable:bool=False,fused:Optional[bool]=None,grad_scale:Optional[Tensor]=None,found_inf:Optional[Tensor]=None,has_complex:bool=False,*,amsgrad:bool,beta1:float,beta2:float,lr:Union[float,Tensor],weight_decay:float,eps:float,maximize:bool,):r"""Functional API that performs AdamW algorithm computation. See :class:`~torch.optim.AdamW` for details. """adam(params,grads,exp_avgs,exp_avg_sqs,max_exp_avg_sqs,state_steps,foreach=foreach,capturable=capturable,differentiable=differentiable,fused=fused,grad_scale=grad_scale,found_inf=found_inf,has_complex=has_complex,amsgrad=amsgrad,beta1=beta1,beta2=beta2,lr=lr,weight_decay=weight_decay,eps=eps,maximize=maximize,decoupled_weight_decay=True,)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.