class Rprop(Optimizer):
    r"""Implements the resilient backpropagation algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                    \\
            &\textbf{input}      : \theta_0 \in \mathbf{R}^d \text{ (params)},
                f(\theta) \text{ (objective)},                                      \\
            &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)},
                \Gamma_{max/min} \text{ (step sizes)}                               \\
            &\textbf{initialize} : g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)}                      \\
            &\rule{110mm}{0.4pt}                                                    \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}            \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
            &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0                      \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max})                                                       \\
            &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0                 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min})                                                       \\
            &\hspace{15mm} g^i_t \leftarrow 0                                       \\
            &\hspace{10mm} \textbf{else} \:                                         \\
            &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1}                         \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
            &\hspace{5mm}g_{prev} \leftarrow g_t                                    \\
            &\rule{110mm}{0.4pt}                                             \\[-1.ex]
            &\bf{return} \: \theta_t                                         \\[-1.ex]
            &\rule{110mm}{0.4pt}                                             \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate, used as the initial per-element
            step size (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), the
            multiplicative decrease and increase factors applied to the step
            size (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
    """

    def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50),
                 foreach: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        # etaminus must shrink the step (0 < etaminus < 1) and etaplus must grow it (> 1).
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))

        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=foreach)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # States pickled before the ``foreach`` option existed lack the key.
        for group in self.param_groups:
            group.setdefault('foreach', None)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = []
            grads = []
            prevs = []
            step_sizes = []
            etaminus, etaplus = group['etas']
            step_size_min, step_size_max = group['step_sizes']
            foreach = group['foreach']

            for p in group['params']:
                if p.grad is None:
                    continue
                params.append(p)
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Rprop does not support sparse gradients')
                grads.append(grad)
                state = self.state[p]

                # Lazy state initialization on the first step that sees a gradient.
                if len(state) == 0:
                    state['step'] = 0
                    state['prev'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Every element starts with the learning rate as its step size.
                    state['step_size'] = torch.full_like(
                        grad, group['lr'], memory_format=torch.preserve_format)

                prevs.append(state['prev'])
                step_sizes.append(state['step_size'])

                state['step'] += 1

            rprop(params,
                  grads,
                  prevs,
                  step_sizes,
                  step_size_min=step_size_min,
                  step_size_max=step_size_max,
                  etaminus=etaminus,
                  etaplus=etaplus,
                  foreach=foreach)

        return loss
def rprop(params: List[Tensor],
          grads: List[Tensor],
          prevs: List[Tensor],
          step_sizes: List[Tensor],
          # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
          # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
          foreach: Optional[bool] = None,
          *,
          step_size_min: float,
          step_size_max: float,
          etaminus: float,
          etaplus: float):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.

    Args:
        params: parameters, updated in place.
        grads: gradients matching ``params``; they are not modified.
        prevs: previous-gradient buffers, overwritten in place with the
            (sign-flip-masked) gradient used in this step.
        step_sizes: per-element step sizes, updated in place.
        foreach: use the multi-tensor (``torch._foreach_*``) implementation;
            ``None`` currently falls back to ``False``.
        step_size_min, step_size_max: clamp bounds for the updated step sizes.
        etaminus, etaplus: factors applied to a step size when the gradient
            flips sign / keeps its sign relative to the previous step.

    Raises:
        RuntimeError: if ``foreach`` is requested under TorchScript.
    """

    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(params,
         grads,
         prevs,
         step_sizes,
         step_size_min=step_size_min,
         step_size_max=step_size_max,
         etaminus=etaminus,
         etaplus=etaplus)


def _single_tensor_rprop(params: List[Tensor],
                         grads: List[Tensor],
                         prevs: List[Tensor],
                         step_sizes: List[Tensor],
                         *,
                         step_size_min: float,
                         step_size_max: float,
                         etaminus: float,
                         etaplus: float):
    """Per-tensor Rprop update; see :func:`rprop` for the argument contract."""
    for param, grad, prev, step_size in zip(params, grads, prevs, step_sizes):
        # +1: gradient kept its sign, -1: it flipped, 0: either value is zero.
        factor = grad.mul(prev).sign()
        # Classify from the raw signs *before* overwriting them, so the result does
        # not depend on the numeric values of etaplus/etaminus (e.g. etaminus == 1).
        # NaN entries match no mask and stay NaN.
        pos = factor.gt(0)
        neg = factor.lt(0)
        zero = factor.eq(0)
        factor[pos] = etaplus
        factor[neg] = etaminus
        factor[zero] = 1

        # update step sizes with the per-element factors
        step_size.mul_(factor).clamp_(step_size_min, step_size_max)

        # After a sign flip the gradient is treated as zero: no parameter move this
        # step, and a zero "previous gradient" for the next one.
        grad = grad.clone(memory_format=torch.preserve_format)
        grad[neg] = 0

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)

        prev.copy_(grad)


def _multi_tensor_rprop(params: List[Tensor],
                        grads: List[Tensor],
                        prevs: List[Tensor],
                        step_sizes: List[Tensor],
                        *,
                        step_size_min: float,
                        step_size_max: float,
                        etaminus: float,
                        etaplus: float):
    """Multi-tensor (``torch._foreach_*``) Rprop update; see :func:`rprop`."""
    if len(params) == 0:
        return

    signs = [s.sign() for s in torch._foreach_mul(grads, prevs)]
    # Sign-flip masks are computed from the raw signs before they are overwritten
    # with factors, so etaminus/etaplus values cannot be confused with each other.
    neg_masks = [s.lt(0) for s in signs]
    for factor, neg in zip(signs, neg_masks):
        pos = factor.gt(0)
        zero = factor.eq(0)
        factor[pos] = etaplus
        factor[neg] = etaminus
        factor[zero] = 1

    # update step sizes with the per-element factors
    torch._foreach_mul_(step_sizes, signs)
    for step_size in step_sizes:
        step_size.clamp_(step_size_min, step_size_max)

    # After a sign flip the gradient is treated as zero. Work on copies collected in
    # a new list so neither the caller's tensors nor the caller's list are modified.
    masked_grads = []
    for grad, neg in zip(grads, neg_masks):
        grad = grad.clone(memory_format=torch.preserve_format)
        grad[neg] = 0
        masked_grads.append(grad)

    # update parameters
    grad_signs = [grad.sign() for grad in masked_grads]
    torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)

    for prev, grad in zip(prevs, masked_grads):
        prev.copy_(grad)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.