# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""

from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _fused_doc,
    _maximize_doc,
    _params_doc,
    _use_grad_for_differentiable,
    DeviceDict,
    Optimizer,
    ParamsT,
)

__all__ = ["SGD", "sgd"]


class SGD(Optimizer):  # noqa: D101
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):  # noqa: D107
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            maximize=maximize,
            foreach=foreach,
            differentiable=differentiable,
            fused=fused,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            self._need_device_dtype_check_for_fused = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)
            group.setdefault("fused", False)

    def _init_group(self, group, params, grads, momentum_buffer_list):
        has_sparse_grad = False

        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self, "_need_device_dtype_check_for_fused", True
                ):
                    _device_dtype_check_for_fused(p)
                    self._need_device_dtype_check_for_fused = False
                params.append(p)
                grads.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                if group["momentum"] != 0:
                    state = self.state[p]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

        return has_sparse_grad
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params: List[Tensor] = []
            grads: List[Tensor] = []
            momentum_buffer_list: List[Optional[Tensor]] = []

            has_sparse_grad = self._init_group(
                group, params, grads, momentum_buffer_list
            )

            sgd(
                params,
                grads,
                momentum_buffer_list,
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                lr=group["lr"],
                dampening=group["dampening"],
                nesterov=group["nesterov"],
                maximize=group["maximize"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

            if group["momentum"] != 0:
                # update momentum_buffers in state
                for p, momentum_buffer in zip(params, momentum_buffer_list):
                    state = self.state[p]
                    state["momentum_buffer"] = momentum_buffer

        return loss
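

# Illustrative sketch (an addition for exposition, not part of the upstream
# module): `step` accepts an optional closure that re-evaluates the model and
# returns the loss, as its docstring above describes. A minimal example of
# that calling convention follows; the helper is never invoked at import time.
def _closure_usage_sketch() -> None:
    model = torch.nn.Linear(2, 1)
    opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
    inp, target = torch.randn(4, 2), torch.randn(4, 1)

    def closure() -> Tensor:
        # Re-evaluate the loss and repopulate gradients before the update.
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inp), target)
        loss.backward()
        return loss

    opt.step(closure)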


SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                          \\
            &\textbf{input}: \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
                \: f(\theta) \text{ (objective)}, \: \lambda \text{ (weight decay)},      \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
                \:\textit{nesterov,} \:\textit{maximize}                           \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                  \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                      \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                        \\
            &\hspace{10mm}\textbf{if} \: t > 1                                            \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t    \\
            &\hspace{10mm}\textbf{else}                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                    \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                        \\
            &\hspace{10mm}\textbf{else}                                            \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t                                    \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                 \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t            \\[-1.ex]
            &\hspace{5mm}\textbf{else}                                             \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t            \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                   \\[-1.ex]
            &\bf{return} \: \theta_t                                               \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                   \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        {_params_doc}
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        nesterov (bool, optional): enables Nesterov momentum. Only applicable
            when momentum is non-zero. (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""
    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.
    """
)
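

# Illustrative sketch (an addition for exposition, not part of the upstream
# module): the note in the docstring above contrasts this implementation's
# momentum recursion with the Sutskever et al. formulation. For a constant
# learning rate the two produce identical parameter trajectories, because the
# Sutskever velocity is exactly lr times the velocity used here. The helper
# below checks this numerically for a constant gradient; it is never invoked
# at import time.
def _momentum_formulation_sketch() -> None:
    lr, mu, g = 0.1, 0.9, 1.0  # pretend the gradient is constant across steps
    p_torch = p_sutskever = 0.0
    v_torch = v_sutskever = 0.0
    for _ in range(3):
        # This module's recursion: v <- mu*v + g; p <- p - lr*v
        v_torch = mu * v_torch + g
        p_torch -= lr * v_torch
        # Sutskever et al.: v <- mu*v + lr*g; p <- p - v
        v_sutskever = mu * v_sutskever + lr * g
        p_sutskever -= v_sutskever
    # Same trajectory up to floating-point rounding.
    assert abs(p_torch - p_sutskever) < 1e-9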
""")defsgd(params:List[Tensor],d_p_list:List[Tensor],momentum_buffer_list:List[Optional[Tensor]],# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627# setting this as kwarg for now as functional API is compiled by torch/distributed/optimhas_sparse_grad:bool=False,foreach:Optional[bool]=None,fused:Optional[bool]=None,grad_scale:Optional[Tensor]=None,found_inf:Optional[Tensor]=None,*,weight_decay:float,momentum:float,lr:float,dampening:float,nesterov:bool,maximize:bool,):r"""Functional API that performs SGD algorithm computation. See :class:`~torch.optim.SGD` for details. """# Respect when the user inputs False/True for foreach or fused. We only want to change# the default when neither have been user-specified. Note that we default to foreach# and pass False to use_fused. This is not a mistake--we want to give the fused impl# bake-in time before making it the default, even if it is typically faster.ifforeachisNoneandfusedisNone:# why must we be explicit about an if statement for torch.jit.is_scripting here?# because JIT can't handle Optionals nor fancy conditionals when scriptingifnottorch.jit.is_scripting():fused,foreach=_default_to_fused_or_foreach(params,differentiable=False,use_fused=False)else:foreach=Falsefused=FalseifforeachisNone:foreach=FalseiffusedisNone:fused=Falseifforeachandtorch.jit.is_scripting():raiseRuntimeError("torch.jit.script not supported with foreach optimizers")iffusedandtorch.jit.is_scripting():raiseRuntimeError("torch.jit.script not supported with fused optimizers")ifforeachandnottorch.jit.is_scripting():func=_multi_tensor_sgdeliffusedandnottorch.jit.is_scripting():func=_fused_sgdelse:func=_single_tensor_sgdfunc(params,d_p_list,momentum_buffer_list,weight_decay=weight_decay,momentum=momentum,lr=lr,dampening=dampening,nesterov=nesterov,has_sparse_grad=has_sparse_grad,maximize=maximize,grad_scale=grad_scale,found_inf=found_inf,)def_single_tensor_sgd(params:List[Tensor],grads:List[Tensor],momentum_buffer_list:List[Optional[Tensor]],grad_scale:Optional[Tensor],found_inf:Optional[Tensor],*,weight_decay:float,momentum:float,lr:float,dampening:float,nesterov:bool,maximize:bool,has_sparse_grad:bool,):assertgrad_scaleisNoneandfound_infisNonefori,paraminenumerate(params):grad=grads[i]ifnotmaximizeelse-grads[i]ifweight_decay!=0:grad=grad.add(param,alpha=weight_decay)ifmomentum!=0:buf=momentum_buffer_list[i]ifbufisNone:buf=torch.clone(grad).detach()momentum_buffer_list[i]=bufelse:buf.mul_(momentum).add_(grad,alpha=1-dampening)ifnesterov:grad=grad.add(buf,alpha=momentum)else:grad=bufparam.add_(grad,alpha=-lr)def_multi_tensor_sgd(params:List[Tensor],grads:List[Tensor],momentum_buffer_list:List[Optional[Tensor]],grad_scale:Optional[Tensor],found_inf:Optional[Tensor],*,weight_decay:float,momentum:float,lr:float,dampening:float,nesterov:bool,maximize:bool,has_sparse_grad:bool,):assertgrad_scaleisNoneandfound_infisNoneiflen(params)==0:returngrouped_tensors=Optimizer._group_tensors_by_device_and_dtype([params,grads,momentum_buffer_list],with_indices=True# type: ignore[list-item])for(device_params_,device_grads_,device_momentum_buffer_list,),indicesingrouped_tensors.values():device_params:List[Tensor]=cast(List[Tensor],device_params_)device_grads:List[Tensor]=cast(List[Tensor],device_grads_)device_has_sparse_grad=has_sparse_gradandany(grad.is_sparseforgradindevice_grads)ifmaximize:device_grads=torch._foreach_neg(device_grads)# type: ignore[assignment]ifweight_decay!=0:# Re-use the intermediate memory (device_grads) already allocated for 
maximizeifmaximize:torch._foreach_add_(device_grads,device_params,alpha=weight_decay)else:device_grads=torch._foreach_add(# type: ignore[assignment]device_grads,device_params,alpha=weight_decay)ifmomentum!=0:bufs:List[Tensor]=[]all_states_with_momentum_buffer=Trueforiinrange(len(device_momentum_buffer_list)):ifdevice_momentum_buffer_list[i]isNone:all_states_with_momentum_buffer=Falsebreakelse:bufs.append(cast(Tensor,device_momentum_buffer_list[i]))ifall_states_with_momentum_buffer:torch._foreach_mul_(bufs,momentum)torch._foreach_add_(bufs,device_grads,alpha=1-dampening)else:bufs=[]foriinrange(len(device_momentum_buffer_list)):ifdevice_momentum_buffer_list[i]isNone:buf=device_momentum_buffer_list[i]=momentum_buffer_list[indices[i]]=torch.clone(device_grads[i]).detach()else:buf=cast(Tensor,device_momentum_buffer_list[i])buf.mul_(momentum).add_(device_grads[i],alpha=1-dampening)bufs.append(buf)ifnesterov:torch._foreach_add_(device_grads,bufs,alpha=momentum)else:device_grads=bufsifnotdevice_has_sparse_grad:# handle internal item() call if lr is a tensorifisinstance(lr,torch.Tensor)andtorch.compiler.is_compiling():grads_x_lr=torch._foreach_mul(device_grads,-lr)torch._foreach_add_(device_params,grads_x_lr)else:torch._foreach_add_(device_params,device_grads,alpha=-lr)else:# foreach APIs don't support sparseforiinrange(len(device_params)):device_params[i].add_(device_grads[i],alpha=-lr)def_fused_sgd(params:List[Tensor],grads:List[Tensor],momentum_buffer_list:List[Optional[Tensor]],grad_scale:Optional[Tensor],found_inf:Optional[Tensor],*,weight_decay:float,momentum:float,lr:float,dampening:float,nesterov:bool,maximize:bool,has_sparse_grad:bool,)->None:ifnotparams:returnifhas_sparse_grad:raiseRuntimeError("`_fused_sgd` does not support sparse gradients")grad_scale_dict:DeviceDict=({grad_scale.device:grad_scale}ifgrad_scaleisnotNoneelse{})found_inf_dict:DeviceDict=({found_inf.device:found_inf}iffound_infisnotNoneelse{})no_momentum_buffer=momentum==0is_first_step=(all(tisNonefortinmomentum_buffer_list)andnotno_momentum_buffer)ifis_first_step:fori,ginenumerate(grads):momentum_buffer_list[i]=torch.empty_like(g)grouped_tensors=Optimizer._group_tensors_by_device_and_dtype([params,grads,momentum_buffer_list],with_indices=False# type: ignore[list-item])for(device,_),((device_params_,device_grads_,device_momentum_buffer_list),_,)ingrouped_tensors.items():device_params:List[Tensor]=cast(List[Tensor],device_params_)device_grads:List[Tensor]=cast(List[Tensor],device_grads_)device_grad_scale,device_found_inf=None,Noneifgrad_scaleisnotNone:device_grad_scale=grad_scale_dict.setdefault(device,grad_scale.to(device))iffound_inf_dictisnotNoneandfound_infisnotNone:device_found_inf=found_inf_dict.setdefault(device,found_inf.to(device))torch._fused_sgd_(device_params,device_grads,[]ifno_momentum_bufferelsecast(List[Tensor],device_momentum_buffer_list),weight_decay=weight_decay,momentum=momentum,lr=lr,dampening=dampening,nesterov=nesterov,maximize=maximize,is_first_step=is_first_step,grad_scale=device_grad_scale,found_inf=device_found_inf,)
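

# Minimal usage sketch (an addition for exposition, not part of the upstream
# module): exercising the functional `sgd` API defined above on a single
# parameter/gradient pair. The `SGD` class is the intended entry point; the
# functional form is what torch/distributed/optim compiles.
if __name__ == "__main__":
    param = torch.zeros(3)
    grad = torch.ones(3)
    momentum_buffers: List[Optional[Tensor]] = [None]
    for _ in range(2):
        sgd(
            [param],
            [grad],
            momentum_buffers,
            weight_decay=0.0,
            momentum=0.9,
            lr=0.1,
            dampening=0.0,
            nesterov=False,
            maximize=False,
        )
    # step 1: buf = g, param = -0.1*g
    # step 2: buf = 0.9*g + g = 1.9*g, param = -(0.1 + 0.19)*g
    print(param)  # tensor([-0.2900, -0.2900, -0.2900])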