import math

import torch
from . import _functional as F
from .optimizer import Optimizer


class ASGD(Optimizer):
    """Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098
    """

    def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, lambd=lambd, alpha=alpha, t0=t0,
                        weight_decay=weight_decay)
        super(ASGD, self).__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            mus = []
            axs = []
            etas = []
            state_steps = []

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('ASGD does not support sparse gradients')
                    grads.append(p.grad)

                    state = self.state[p]
                    # State initialization
                    if len(state) == 0:
                        state['step'] = 0
                        state['eta'] = group['lr']
                        state['mu'] = 1
                        state['ax'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    mus.append(state['mu'])
                    axs.append(state['ax'])
                    etas.append(state['eta'])

                    state['step'] += 1
                    state_steps.append(state['step'])

            F.asgd(params_with_grad,
                   grads,
                   axs,
                   mus,
                   etas,
                   weight_decay=group['weight_decay'],
                   lambd=group['lambd'])

            # update eta and mu
            for p, mu, eta in zip(params_with_grad, mus, etas):
                state = self.state[p]
                state['eta'] = (group['lr'] /
                                math.pow((1 + group['lambd'] * group['lr'] * state['step']), group['alpha']))
                state['mu'] = 1 / max(1, state['step'] - group['t0'])

        return loss
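

# The functional kernel F.asgd used in step() lives in torch.optim._functional
# and is not part of this excerpt. The sketch below is a minimal, illustrative
# reconstruction of the per-parameter update it is expected to perform
# (weight decay, decay term, gradient step, then Polyak-Ruppert averaging into
# `ax`); the name `_asgd_reference` is hypothetical and the body is an
# assumption, not the library source.
def _asgd_reference(params, grads, axs, mus, etas, *, weight_decay, lambd):
    for param, grad, ax, mu, eta in zip(params, grads, axs, mus, etas):
        if weight_decay != 0:
            # fold the L2 penalty into the gradient
            grad = grad.add(param, alpha=weight_decay)

        # decay term: param <- param * (1 - lambd * eta)
        param.mul_(1 - lambd * eta)

        # gradient step: param <- param - eta * grad
        param.add_(grad, alpha=-eta)

        # averaging: once averaging has started (mu < 1), ax tracks the running
        # average of the iterates; before that it simply mirrors param
        if mu != 1:
            ax.add_(param.sub(ax).mul(mu))
        else:
            ax.copy_(param)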
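

# Example usage (a minimal sketch): ASGD is constructed like any other
# torch.optim optimizer; the averaged iterates accumulate in
# optimizer.state[p]['ax'] and can optionally be copied into the model after
# training. The model, data, and training loop below are illustrative.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.ASGD(model.parameters(), lr=1e-2, lambd=1e-4,
                                 alpha=0.75, t0=1e6, weight_decay=0)

    for _ in range(100):
        inputs = torch.randn(32, 10)
        targets = torch.randn(32, 1)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Evaluate with the averaged parameters (optional).
    with torch.no_grad():
        for p in model.parameters():
            p.copy_(optimizer.state[p]['ax'])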