[docs]classAdam(Optimizer):def__init__(self,params:ParamsT,lr:Union[float,Tensor]=1e-3,betas:Tuple[Union[float,Tensor],Union[float,Tensor]]=(0.9,0.999),eps:float=1e-8,weight_decay:float=0,amsgrad:bool=False,*,foreach:Optional[bool]=None,maximize:bool=False,capturable:bool=False,differentiable:bool=False,fused:Optional[bool]=None,):ifisinstance(lr,Tensor):ifforeachandnotcapturable:raiseValueError("lr as a Tensor is not supported for capturable=False and foreach=True")iflr.numel()!=1:raiseValueError("Tensor lr must be 1-element")ifnot0.0<=lr:raiseValueError(f"Invalid learning rate: {lr}")ifnot0.0<=eps:raiseValueError(f"Invalid epsilon value: {eps}")ifnot0.0<=betas[0]<1.0:raiseValueError(f"Invalid beta parameter at index 0: {betas[0]}")ifnot0.0<=betas[1]<1.0:raiseValueError(f"Invalid beta parameter at index 1: {betas[1]}")ifnot0.0<=weight_decay:raiseValueError(f"Invalid weight_decay value: {weight_decay}")ifnot((isinstance(betas[0],float)andisinstance(betas[1],float))or(isinstance(betas[0],Tensor)andisinstance(betas[1],Tensor))):raiseValueError("betas must be either both floats or both Tensors")ifisinstance(betas[0],Tensor):ifnotcapturableandforeach:raiseValueError("betas[0] as a Tensor is not supported for capturable=False and foreach=True")ifbetas[0].numel()!=1:raiseValueError("Tensor betas[0] must be 1-element")ifisinstance(betas[1],Tensor):ifnotcapturableandforeach:raiseValueError("betas[1] as a Tensor is not supported for capturable=False and foreach=True")ifbetas[1].numel()!=1:raiseValueError("Tensor betas[1] must be 1-element")defaults=dict(lr=lr,betas=betas,eps=eps,weight_decay=weight_decay,amsgrad=amsgrad,maximize=maximize,foreach=foreach,capturable=capturable,differentiable=differentiable,fused=fused,)super().__init__(params,defaults)iffused:ifdifferentiable:raiseRuntimeError("`fused` does not support `differentiable`")self._step_supports_amp_scaling=True# TODO(crcrpar): [low prec params & their higher prec copy]# Support AMP with FP16/BF16 model params 
which would need# higher prec copy of params to do update math in higher prec to# alleviate the loss of information.ifforeach:raiseRuntimeError("`fused` and `foreach` cannot be `True` together.")def__setstate__(self,state):super().__setstate__(state)forgroupinself.param_groups:group.setdefault("amsgrad",False)group.setdefault("maximize",False)group.setdefault("foreach",None)group.setdefault("capturable",False)group.setdefault("differentiable",False)fused=group.setdefault("fused",None)forpingroup["params"]:p_state=self.state.get(p,[])iflen(p_state)!=0andnottorch.is_tensor(p_state["step"]):step_val=float(p_state["step"])p_state["step"]=(torch.tensor(step_val,dtype=_get_scalar_dtype(is_fused=fused),device=p.device,)ifgroup["capturable"]orgroup["fused"]elsetorch.tensor(step_val,dtype=_get_scalar_dtype()))def_init_group(self,group,params_with_grad,grads,exp_avgs,exp_avg_sqs,max_exp_avg_sqs,state_steps,):has_complex=Falseforpingroup["params"]:ifp.gradisnotNone:has_complex|=torch.is_complex(p)params_with_grad.append(p)ifp.grad.is_sparse:raiseRuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")grads.append(p.grad)state=self.state[p]# Lazy state initializationiflen(state)==0:ifgroup["fused"]:_device_dtype_check_for_fused(p)# note(crcrpar): [special device hosting for step]# Deliberately host `step` on CPU if both capturable and fused are off.# This is because kernel launches are costly on CUDA and XLA.state["step"]=(torch.zeros((),dtype=_get_scalar_dtype(is_fused=group["fused"]),device=p.device,)ifgroup["capturable"]orgroup["fused"]elsetorch.tensor(0.0,dtype=_get_scalar_dtype()))# Exponential moving average of gradient valuesstate["exp_avg"]=torch.zeros_like(p,memory_format=torch.preserve_format)# Exponential moving average of squared gradient valuesstate["exp_avg_sq"]=torch.zeros_like(p,memory_format=torch.preserve_format)ifgroup["amsgrad"]:# Maintains max of all exp. moving avg. of sq. grad. 
valuesstate["max_exp_avg_sq"]=torch.zeros_like(p,memory_format=torch.preserve_format)exp_avgs.append(state["exp_avg"])exp_avg_sqs.append(state["exp_avg_sq"])ifgroup["amsgrad"]:max_exp_avg_sqs.append(state["max_exp_avg_sq"])ifgroup["differentiable"]andstate["step"].requires_grad:raiseRuntimeError("`requires_grad` is not supported for `step` in differentiable mode")# Foreach without capturable does not support a tensor lrif(group["foreach"]andtorch.is_tensor(group["lr"])andnotgroup["capturable"]):raiseRuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")state_steps.append(state["step"])returnhas_complex
[docs]@_use_grad_for_differentiabledefstep(self,closure=None):"""Perform a single optimization step. Args: closure (Callable, optional): A closure that reevaluates the model and returns the loss. """self._cuda_graph_capture_health_check()loss=NoneifclosureisnotNone:withtorch.enable_grad():loss=closure()forgroupinself.param_groups:params_with_grad:List[Tensor]=[]grads:List[Tensor]=[]exp_avgs:List[Tensor]=[]exp_avg_sqs:List[Tensor]=[]max_exp_avg_sqs:List[Tensor]=[]state_steps:List[Tensor]=[]beta1,beta2=group["betas"]has_complex=self._init_group(group,params_with_grad,grads,exp_avgs,exp_avg_sqs,max_exp_avg_sqs,state_steps,)adam(params_with_grad,grads,exp_avgs,exp_avg_sqs,max_exp_avg_sqs,state_steps,amsgrad=group["amsgrad"],has_complex=has_complex,beta1=beta1,beta2=beta2,lr=group["lr"],weight_decay=group["weight_decay"],eps=group["eps"],maximize=group["maximize"],foreach=group["foreach"],capturable=group["capturable"],differentiable=group["differentiable"],fused=group["fused"],grad_scale=getattr(self,"grad_scale",None),found_inf=getattr(self,"found_inf",None),)returnloss
Adam.__doc__ = (
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}          \\
            &\hspace{13mm}      \lambda \text{ (weight decay)},  \: \textit{amsgrad},
                \:\textit{maximize},  \: \epsilon \text{ (epsilon)}                              \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_{t-1}}^{max},
                \widehat{v_t})                                                                   \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big)                                 \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
    """
    + rf"""
    Args:
        {_params_doc}
        lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
            is not yet supported for all our implementations. Please use a float
            LR if you are not also specifying fused=True or capturable=True.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}
        {_fused_doc}
    .. Note::
        A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """
)


def _single_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: Union[float, Tensor],
    beta2: Union[float, Tensor],
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one Adam update to each parameter, one tensor at a time.

    Mutates ``params``, ``exp_avgs``, ``exp_avg_sqs``, ``state_steps`` (and
    ``max_exp_avg_sqs`` when ``amsgrad``) in place. Gradient scaling is not
    supported here: ``grad_scale`` and ``found_inf`` must be None.
    """
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)
        assert isinstance(beta1, float)
        assert isinstance(beta2, float)

    # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
    # treating it as a scalar.
    # Note: ensure type declaration is under conditional check for isinstance
    # or else torchscript will get cranky about the DeviceDict type.
    if isinstance(beta1, Tensor):
        beta1_dict: Optional[DeviceDtypeDict] = {(beta1.device, beta1.dtype): beta1}
    else:
        beta1_dict = None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch.compiler.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            # Operate on the real view; max_exp_avg_sqs[i] is switched back at the end.
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        device = param.device

        if beta1_dict is not None:
            dtype = param.dtype  # type: ignore[union-attr]

            # cast to workaround https://github.com/pytorch/pytorch/issues/140601
            key = (device, dtype)
            if key not in beta1_dict:
                beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True)  # type: ignore[union-attr]

            device_beta1: Union[float, Tensor] = beta1_dict[key]
        else:
            device_beta1 = beta1

        # Decay the first and second moment running average coefficient
        exp_avg.lerp_(grad, 1 - device_beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if capturable or differentiable:
            # Tensor math throughout so `step` never needs a host sync.
            step = step_t

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sq = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sq = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = bias_correction2**0.5

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

        # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])


def _multi_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: Union[float, Tensor],
    beta2: Union[float, Tensor],
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one Adam update with ``torch._foreach_*`` ops, per device/dtype group.

    Tensor ``lr``/``beta1``/``beta2`` are only accepted together with
    ``capturable=True``; otherwise this raises.

    Raises:
        RuntimeError: for a Tensor ``lr`` without ``capturable``.
        ValueError: for a Tensor beta without ``capturable``, or one that is not
            1-element.
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    if isinstance(beta1, Tensor):
        if not capturable:
            raise ValueError(
                "beta1 as a Tensor is not supported for capturable=False and foreach=True"
            )
        if beta1.numel() != 1:
            raise ValueError("Tensor beta1 must be 1-element")

    if isinstance(beta2, Tensor):
        if not capturable:
            raise ValueError(
                "beta2 as a Tensor is not supported for capturable=False and foreach=True"
            )
        if beta2.numel() != 1:
            raise ValueError("Tensor beta2 must be 1-element")

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch.compiler.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )

    # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    beta1_dict: Optional[DeviceDict] = (  # type: ignore[attr-defined]
        {beta1.device: beta1}
        if isinstance(beta1, Tensor) and str(beta1.device) != "cpu"
        else None
    )

    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device = device_params[0].device
        if beta1_dict is not None and device not in beta1_dict:
            beta1_dict[device] = beta1.to(device=device, non_blocking=True)  # type: ignore[union-attr, attr-defined]

        device_beta1 = beta1_dict[device] if beta1_dict else beta1

        # Handle complex parameters
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # Decay the first and second moment running average coefficient
        # Use device beta1 if beta1 is a tensor to ensure all
        # tensors are on the same device
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)

        # Due to the strictness of the _foreach_addcmul API, we can't have a single
        # tensor scalar as the scalar arg (only python number is supported there)
        # as a result, separate out the value mul
        # Filed https://github.com/pytorch/pytorch/issues/139795
        if isinstance(beta2, torch.Tensor):
            scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2)  # type: ignore[assignment]
            value = 1.0
        else:
            scaled_device_grads = device_grads  # type: ignore[assignment]
            value = 1 - beta2

        torch._foreach_addcmul_(
            device_exp_avg_sqs, scaled_device_grads, device_grads, value
        )

        # Delete the local intermediate(s) since they won't be used anymore to save on peak memory
        del device_grads
        del scaled_device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)  # type: ignore[arg-type]
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)  # type: ignore[arg-type]
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params,
                device_exp_avgs,
                exp_avg_sq_sqrt,
                step_size,  # type: ignore[arg-type]
            )


def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,  # Needed for consistency.
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    """Apply one Adam update via ``torch._fused_adam_``, per device/dtype group.

    ``grad_scale``/``found_inf`` (when given) and a non-CPU Tensor ``lr`` are
    copied to each group's device once and cached. When ``found_inf`` is given,
    the step increment is undone by subtracting it from the step counters.

    Raises:
        RuntimeError: if ``differentiable`` is True.
    """
    if not params:
        return
    if differentiable:
        raise RuntimeError("Adam with fused=True does not support differentiable=True")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: Optional[DeviceDict] = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_exp_avgs_,
            device_exp_avg_sqs_,
            device_max_exp_avg_sqs,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        if device.type == "mps":  # type: ignore[union-attr]
            assert found_inf is None and grad_scale is None

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device, non_blocking=True)
            )
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(
                device, found_inf.to(device, non_blocking=True)
            )

        # Resolve this group's lr without overwriting the caller's `lr`. Several
        # groups can share a device (one per dtype); previously `lr` was only
        # reassigned when a device was first seen, so a later group on an
        # already-cached device reused the tensor moved for the previous device.
        device_lr = lr
        if lr_dict is not None:
            if device not in lr_dict:
                lr_dict[device] = lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            device_lr = lr_dict[device]

        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,  # type: ignore[arg-type]
            device_state_steps,
            amsgrad=amsgrad,
            lr=device_lr,  # type: ignore[arg-type]
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    Note: ``Adam.step`` forwards ``group["betas"]`` here, which ``Adam`` accepts
    as 1-element Tensors, so ``beta1``/``beta2`` may be Tensors despite the
    ``float`` annotation (kept for TorchScript compatibility).
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported cases where lr or either beta is
        # a Tensor and capturable=False: _multi_tensor_adam rejects all of those, so
        # defaulting to it would turn a valid Adam(...) configuration into an error.
        if (
            foreach
            and not capturable
            and (
                isinstance(lr, Tensor)
                or isinstance(beta1, Tensor)
                or isinstance(beta2, Tensor)
            )
        ):
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.