# Module-level imports this source excerpt relies on (torch/optim/asgd.py);
# the exact import location of ``is_compiling`` may vary between PyTorch versions.
from typing import List, Optional

import torch
from torch import Tensor
from torch._utils import is_compiling

from .optimizer import (
    Optimizer,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _foreach_doc,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
)

__all__ = ["ASGD", "asgd"]


class ASGD(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-2,
        lambd=1e-4,
        alpha=0.75,
        t0=1e6,
        weight_decay=0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
        # Older checkpoints stored ``step``, ``eta`` and ``mu`` as Python scalars;
        # convert them to tensors so the functional API below can consume them.
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))
        eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["eta"]
        )
        if not eta_is_tensor:
            for s in state_values:
                s["eta"] = torch.tensor(s["eta"])
        mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["mu"]
        )
        if not mu_is_tensor:
            for s in state_values:
                s["mu"] = torch.tensor(float(s["mu"]))

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)
                    state["eta"] = torch.tensor(group["lr"])
                    state["mu"] = torch.tensor(1.0)
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            mus = []
            axs = []
            etas = []
            state_steps = []

            self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
            )

        return loss
ASGD.__doc__ = fr"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """


def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
    )


def _to_tensor(x):
    # Defined at module level so that both the single-tensor and the
    # multi-tensor implementations below can share it.
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x)
    return x


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1
        step = _get_value(step_t)

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        eta_value = _get_value(eta)
        # decay term
        param.mul_(1 - lambd * eta_value)

        # update parameter
        param.add_(grad, alpha=-eta_value)

        # averaging
        if is_compiling() or mu.item() != 1:
            ax.add_(param.sub(ax).mul(mu))
        else:
            ax.copy_(param)

        new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha))
        eta.copy_(new_eta)
        new_mu = _to_tensor(1 / max(1, step - t0))
        mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]
    )
    for (
        grouped_params,
        grouped_grads,
        grouped_axs,
        grouped_mus,
        grouped_etas,
        grouped_state_steps,
    ), _ in grouped_tensors.values():
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

        def _view_complex_as_real(tensor_list):
            return [
                torch.view_as_real(t) if torch.is_complex(t) else t
                for t in tensor_list
            ]

        grouped_grads = _view_complex_as_real(grouped_grads)
        grouped_params = _view_complex_as_real(grouped_params)
        grouped_axs = _view_complex_as_real(grouped_axs)
        # update step
        torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (grouped_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                grouped_grads = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # decay term
        eta = _get_value(grouped_etas[0])
        torch._foreach_mul_(grouped_params, 1 - lambd * eta)

        # update parameter
        torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta)

        # averaging
        for i in range(len(grouped_axs)):
            if is_compiling() or grouped_mus[i].item() != 1:
                grouped_axs[i].add_(
                    grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i])
                )
            else:
                grouped_axs[i].copy_(grouped_params[i])

        # update eta and mu
        for i in range(len(grouped_mus)):
            new_eta = _to_tensor(
                lr / (1 + lambd * lr * _get_value(grouped_state_steps[i]) ** alpha)
            )
            grouped_etas[i].copy_(new_eta)
            new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0))
            grouped_mus[i].copy_(new_mu)
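For reference, a minimal usage sketch of the ASGD optimizer defined above. The model, loss, and data here are placeholder assumptions for illustration only, not part of the source:

import torch

# Hypothetical toy model and data, assumed only for this example.
model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.ASGD(
    model.parameters(), lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6
)

for _ in range(100):
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

The per-parameter running average is kept in optimizer.state[p]["ax"], and the averaging coefficient mu only drops below 1 once the step count exceeds t0, matching the mu update in the code above.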