# Imports reconstructed from the surrounding torch.optim module context; the helper
# names below live in torch/optim/optimizer.py (an assumption based on that context).
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _params_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)

__all__ = ["ASGD", "asgd"]


class ASGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lambd: float = 1e-4,
        alpha: float = 0.75,
        t0: float = 1e6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(
                            p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(
                            p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
                        )

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.zeros(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["eta"] = (
                        torch.as_tensor(
                            group["lr"], device=p.device, dtype=_get_scalar_dtype()
                        )
                        .clone()
                        .detach()
                    )
                    state["mu"] = torch.ones(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])

        return has_complex
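    # Note (added for this listing): the per-parameter state kept by this optimizer,
    # as initialized in _init_group above, is:
    #   "step" -- scalar step counter,
    #   "eta"  -- current step size, initialized to the group's lr,
    #   "mu"   -- current averaging weight, initialized to 1,
    #   "ax"   -- the averaged parameter tensor (same shape as the parameter).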
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            mus: List[Tensor] = []
            axs: List[Tensor] = []
            etas: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss
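
# Note (added for this listing): in scalar form, the helpers below apply, for each
# parameter p with gradient g at step k (ignoring the maximize/complex handling):
#
#     g   <- g + weight_decay * p                  # optional L2 penalty
#     p   <- p * (1 - lambd * eta) - eta * g       # decayed SGD step
#     ax  <- ax + mu * (p - ax)                    # running average of iterates
#     eta <- lr / (1 + lambd * lr * k) ** alpha    # step-size schedule
#     mu  <- 1 / max(1, k - t0)                    # averaging weight
#
# Since mu stays 1 while k <= t0 + 1, ax simply tracks p until the step count passes
# t0, after which it becomes (approximately) the average of the later iterates.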
ASGD.__doc__=rf"""Implements Averaged Stochastic Gradient Descent. It has been proposed in `Acceleration of stochastic approximation by averaging`_. Args:{_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-2) lambd (float, optional): decay term (default: 1e-4) alpha (float, optional): power for eta update (default: 0.75) t0 (float, optional): point at which to start averaging (default: 1e6) weight_decay (float, optional): weight decay (L2 penalty) (default: 0){_foreach_doc}{_maximize_doc}{_differentiable_doc}{_capturable_doc} .. _Acceleration of stochastic approximation by averaging: https://dl.acm.org/citation.cfm?id=131098 """def_single_tensor_asgd(params:List[Tensor],grads:List[Tensor],axs:List[Tensor],mus:List[Tensor],etas:List[Tensor],state_steps:List[Tensor],*,lambd:float,lr:float,t0:float,alpha:float,weight_decay:float,maximize:bool,differentiable:bool,capturable:bool,has_complex:bool,):fori,paraminenumerate(params):grad=grads[i]grad=gradifnotmaximizeelse-gradmu=mus[i]ax=axs[i]eta=etas[i]step_t=state_steps[i]# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]ifnottorch.compiler.is_compiling()andcapturable:capturable_supported_devices=_get_capturable_supported_devices()assert(param.device.type==mu.device.type==eta.device.type==step_t.device.typeandparam.device.typeincapturable_supported_devices),(f"If capturable=True, params, mus, etas, and state_steps must be "f"on supported devices: {capturable_supported_devices}.")iftorch.is_complex(param):grad=torch.view_as_real(grad)param=torch.view_as_real(param)ax=torch.view_as_real(ax)# update stepstep_t+=1ifweight_decay!=0:grad=grad.add(param,alpha=weight_decay)ifcapturable:param.mul_(1-lambd*eta)param.addcmul_(grad,eta,value=-1)# update parameterelse:eta_value=_get_value(eta)param.mul_(1-lambd*eta_value)# decay termparam.add_(grad,alpha=-eta_value)# update parameter# averagingifcapturableormu.item()!=1:ax.add_(param.sub(ax).mul_(mu))else:ax.copy_(param)ifcapturable:eta.copy_(lr/((1+lambd*lr*step_t)**alpha))mu.copy_(1/torch.maximum(step_t-t0,torch.ones_like(step_t)))else:step=_get_value(step_t)new_eta=torch.as_tensor(lr/((1+lambd*lr*step)**alpha))eta.copy_(new_eta)new_mu=torch.as_tensor(1/max(1,step-t0))mu.copy_(new_mu)def_multi_tensor_asgd(params:List[Tensor],grads:List[Tensor],axs:List[Tensor],mus:List[Tensor],etas:List[Tensor],state_steps:List[Tensor],*,lambd:float,lr:float,t0:float,alpha:float,weight_decay:float,maximize:bool,differentiable:bool,capturable:bool,has_complex:bool,):iflen(params)==0:returnassertnotdifferentiable,"_foreach ops don't support autograd"# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]ifnottorch.compiler.is_compiling()andcapturable:capturable_supported_devices=_get_capturable_supported_devices(supports_xla=False)assertall(p.device.type==mu.device.type==eta.device.type==step.device.typeandp.device.typeincapturable_supported_devicesforp,mu,eta,stepinzip(params,mus,etas,state_steps)),f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."grouped_tensors=Optimizer._group_tensors_by_device_and_dtype([params,grads,axs,mus,etas,state_steps]# type: 
ignore[list-item])for(device,_),((grouped_params_,grouped_grads_,grouped_axs_,grouped_mus_,grouped_etas_,grouped_state_steps_,),_,)ingrouped_tensors.items():grouped_params=cast(List[Tensor],grouped_params_)grouped_grads=cast(List[Tensor],grouped_grads_)grouped_axs=cast(List[Tensor],grouped_axs_)grouped_mus=cast(List[Tensor],grouped_mus_)grouped_etas=cast(List[Tensor],grouped_etas_)grouped_state_steps=cast(List[Tensor],grouped_state_steps_)ifhas_complex:_view_as_real(grouped_params,grouped_grads,grouped_axs)ifmaximize:grouped_grads=torch._foreach_neg(grouped_grads)# type: ignore[assignment]# Update steps# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just# wrapped it once now. The alpha is required to assure we go to the right overload.ifnottorch.compiler.is_compiling()andgrouped_state_steps[0].is_cpu:torch._foreach_add_(grouped_state_steps,torch.tensor(1.0,device="cpu"),alpha=1.0)else:torch._foreach_add_(grouped_state_steps,1)# intermediate = grad + param * lambdintermediate:Union[Tuple[Tensor,...],List[Tensor]]ifweight_decay!=0:ifmaximize:torch._foreach_add_(grouped_grads,grouped_params,alpha=weight_decay)intermediate=grouped_gradselse:intermediate=torch._foreach_add(grouped_grads,grouped_params,alpha=weight_decay)torch._foreach_add_(intermediate,grouped_params,alpha=lambd)else:intermediate=torch._foreach_add(grouped_grads,grouped_params,alpha=lambd)# update param# param * (1 - lambd * eta) - eta * grad# => param - param * lambd * eta - eta * grad# => param - eta * intermediatetorch._foreach_addcmul_(grouped_params,intermediate,grouped_etas,value=-1)delintermediate# update grouped_axs# averaging: ax = ax + mu * (param - ax)# Note (mlazos): We can't use lerp here since it requires weight to be float64# and our grouping code requires dtypes to match for all tensors in a group (and it should, since# we use the mus in other places)# all dtypes need to match, so we could introduce a cast in a loop# but since this only adds one additional kernel launch, this looks like the cleaner# and faster solutionintermediate=torch._foreach_sub(grouped_params,grouped_axs)torch._foreach_addcmul_(grouped_axs,intermediate,grouped_mus)delintermediatenew_etas:Union[Tuple[Tensor,...],List[Tensor]]new_mus:Union[Tuple[Tensor,...],List[Tensor]]ifcapturable:# update grouped_musnew_mus=torch._foreach_sub(grouped_state_steps,t0)torch._foreach_maximum_(new_mus,1.0)torch._foreach_reciprocal_(new_mus)torch._foreach_copy_(grouped_mus,new_mus)delnew_mus# update eta = lr / ((1 + lambd * lr * step)^alpha)new_etas=torch._foreach_mul(grouped_state_steps,lambd)torch._foreach_mul_(new_etas,lr)torch._foreach_add_(new_etas,1)torch._foreach_pow_(new_etas,alpha)torch._foreach_reciprocal_(new_etas)torch._foreach_mul_(new_etas,lr)torch._foreach_copy_(grouped_etas,new_etas)else:new_etas=[torch.as_tensor(lr/((1+lambd*lr*step)**alpha),device=device)forstepingrouped_state_steps]new_mus=[torch.as_tensor(1/max(1,_get_value(step)-t0),device=device)forstepingrouped_state_steps]torch._foreach_copy_(grouped_etas,new_etas)torch._foreach_copy_(grouped_mus,new_mus)@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)defasgd(params:List[Tensor],grads:List[Tensor],axs:List[Tensor],mus:List[Tensor],etas:List[Tensor],state_steps:List[Tensor],# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627# setting this as kwarg for now as functional API 
is compiled by torch/distributed/optimforeach:Optional[bool]=None,maximize:bool=False,differentiable:bool=False,capturable:bool=False,has_complex:bool=False,*,lambd:float,lr:float,t0:float,alpha:float,weight_decay:float,):r"""Functional API that performs asgd algorithm computation. See :class:`~torch.optim.ASGD` for details. """ifforeachisNone:_,foreach=_default_to_fused_or_foreach(params,differentiable,use_fused=False)ifforeachandtorch.jit.is_scripting():raiseRuntimeError("torch.jit.script not supported with foreach optimizers")ifforeachandnottorch.jit.is_scripting():func=_multi_tensor_asgdelse:func=_single_tensor_asgdfunc(params,grads,axs,mus,etas,state_steps,lambd=lambd,lr=lr,t0=t0,alpha=alpha,weight_decay=weight_decay,maximize=maximize,differentiable=differentiable,capturable=capturable,has_complex=has_complex,)
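

if __name__ == "__main__":
    # Illustrative usage sketch appended to this listing (not part of the original
    # torch/optim/asgd.py): fit a single weight vector with ASGD on a toy quadratic
    # and inspect the averaged iterate kept in state["ax"]. The toy data, lr, and t0
    # values below are arbitrary choices for demonstration.
    torch.manual_seed(0)
    w = torch.zeros(3, requires_grad=True)
    target = torch.tensor([1.0, -2.0, 3.0])
    opt = ASGD([w], lr=1e-1, t0=5)  # averaging kicks in once the step count exceeds t0
    for _ in range(100):
        opt.zero_grad()
        loss = ((w - target) ** 2).sum()
        loss.backward()
        opt.step()
    print("final w :", w.detach())
    print("averaged:", opt.state[w]["ax"])  # running average of the later iterates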