class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize}                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                               \\
            &\hspace{10mm}\textbf{if} \: t > 1                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t           \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                       \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                               \\
            &\hspace{10mm}\textbf{else}                                                   \\[-1.ex]
            &\hspace{15mm} g_t  \leftarrow  \textbf{b}_t                                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                        \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t                   \\[-1.ex]
            &\hspace{5mm}\textbf{else}                                                    \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t                   \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
        differentiable (bool, optional): stored in every parameter group's
            defaults alongside the other hyper-parameters (default: False)

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative, or if
            ``nesterov`` is requested without a positive momentum and zero dampening.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize=False,
                 foreach: Optional[bool] = None,
                 differentiable=False):
        # ``required`` is a sentinel meaning "no lr given"; only validate real values.
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'weight_decay': weight_decay,
            'nesterov': nesterov,
            'maximize': maximize,
            'foreach': foreach,
            'differentiable': differentiable,
        }
        # Nesterov look-ahead only makes sense with an undamped, non-zero momentum.
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, back-filling options that older checkpoints may lack."""
        super().__setstate__(state)
        missing_defaults = (
            ('nesterov', False),
            ('maximize', False),
            ('foreach', None),
            ('differentiable', False),
        )
        for group in self.param_groups:
            for key, fallback in missing_defaults:
                group.setdefault(key, fallback)

    # NOTE(review): presumably this decorator consults group['differentiable'] to decide
    # whether autograd tracks the update -- confirm against its definition.
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure must be able to run backward(), so re-enable autograd for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            has_sparse_grad = False

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue  # parameters without a gradient are skipped entirely
                params_with_grad.append(p)
                d_p_list.append(grad)
                if grad.is_sparse:
                    has_sparse_grad = True
                # A missing buffer is reported as None so the functional API creates it.
                momentum_buffer_list.append(self.state[p].get('momentum_buffer'))

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'])

            # The functional API may have created new buffers in the list; persist them.
            for p, buf in zip(params_with_grad, momentum_buffer_list):
                self.state[p]['momentum_buffer'] = buf

        return loss
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        has_sparse_grad: bool = None,
        foreach: bool = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    Args:
        params: parameters, updated in place.
        d_p_list: gradients, one per entry of ``params``.
        momentum_buffer_list: one momentum buffer per parameter, or ``None`` where no
            buffer exists yet. When ``momentum != 0`` the ``None`` entries are replaced
            in place with freshly created buffers, so callers can persist them.
        has_sparse_grad: whether any gradient is sparse. ``None`` lets the foreach
            implementation detect it itself.
        foreach: use the multi-tensor (``torch._foreach_*``) implementation.
            ``None`` currently falls back to ``False``.

    Raises:
        RuntimeError: if ``foreach`` is requested while running under TorchScript.
    """

    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize)


def _single_tensor_sgd(params: List[Tensor],
                       d_p_list: List[Tensor],
                       momentum_buffer_list: List[Optional[Tensor]],
                       *,
                       weight_decay: float,
                       momentum: float,
                       lr: float,
                       dampening: float,
                       nesterov: bool,
                       maximize: bool,
                       has_sparse_grad: bool):
    """Per-parameter SGD update; gradients are never modified in place.

    ``has_sparse_grad`` is accepted for signature parity with the foreach path
    and is not used here.
    """

    for i, param in enumerate(params):
        d_p = d_p_list[i] if not maximize else -d_p_list[i]

        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                # First step: the buffer starts as a detached copy of the gradient.
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)


def _multi_tensor_sgd(params: List[Tensor],
                      grads: List[Tensor],
                      momentum_buffer_list: List[Optional[Tensor]],
                      *,
                      weight_decay: float,
                      momentum: float,
                      lr: float,
                      dampening: float,
                      nesterov: bool,
                      maximize: bool,
                      has_sparse_grad: bool):
    """Foreach (multi-tensor) SGD update, numerically matching ``_single_tensor_sgd``.

    Like the single-tensor path, this never writes into the caller's gradient
    tensors. When neither ``maximize`` nor ``weight_decay`` applies, ``grads``
    still aliases each ``param.grad``. For that reason every transformation of
    ``grads`` below is out of place. Only the momentum buffers and the
    parameters are updated in place.
    """

    if len(params) == 0:
        return

    if has_sparse_grad is None:
        has_sparse_grad = any(grad.is_sparse for grad in grads)

    if maximize:
        grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]

    if weight_decay != 0:
        grads = torch._foreach_add(grads, params, alpha=weight_decay)

    if momentum != 0:
        if all(buf is not None for buf in momentum_buffer_list):
            # Fast path: every buffer exists, so update them all in one fused call pair.
            bufs = list(momentum_buffer_list)
            torch._foreach_mul_(bufs, momentum)
            torch._foreach_add_(bufs, grads, alpha=1 - dampening)
        else:
            # Some buffers are missing: create those from the gradient, update the rest.
            bufs = []
            for i, buf in enumerate(momentum_buffer_list):
                if buf is None:
                    buf = torch.clone(grads[i]).detach()
                    momentum_buffer_list[i] = buf
                else:
                    buf.mul_(momentum).add_(grads[i], alpha=1 - dampening)
                bufs.append(buf)

        if nesterov:
            # Out of place: an in-place add here would overwrite each param.grad
            # whenever ``grads`` still aliases the caller's gradients.
            grads = torch._foreach_add(grads, bufs, alpha=momentum)
        else:
            grads = bufs

    if not has_sparse_grad:
        torch._foreach_add_(params, grads, alpha=-lr)
    else:
        # foreach APIs dont support sparse
        for param, grad in zip(params, grads):
            param.add_(grad, alpha=-lr)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainer of this site, Facebook applies its Cookies Policy. Learn more, including about available controls: Cookies Policy.