import torch
from . import _functional as F
from .optimizer import Optimizer


class Rprop(Optimizer):
    r"""Implements the resilient backpropagation algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input}      : \theta_0 \in \mathbf{R}^d \text{ (params)}, f(\theta)
                \text{ (objective)}, \\
            &\hspace{13mm}         \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)} \\
            &\textbf{initialize} : g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)} \\
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
            &\hspace{10mm}\textbf{if} \: g^i_{prev} g^i_t > 0 \\
            &\hspace{15mm}\eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max}) \\
            &\hspace{10mm}\textbf{else if} \: g^i_{prev} g^i_t < 0 \\
            &\hspace{15mm}\eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min}) \\
            &\hspace{10mm}\textbf{else} \: \\
            &\hspace{15mm}\eta^i_t \leftarrow \eta^i_{t-1} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \eta_t \mathrm{sign}(g_t) \\
            &\hspace{5mm}g_{prev} \leftarrow g_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus),
            multiplicative decrease and increase factors
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
    """

    def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))

        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes)
        super(Rprop, self).__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = []
            grads = []
            prevs = []
            step_sizes = []

            for p in group['params']:
                if p.grad is None:
                    continue
                params.append(p)
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Rprop does not support sparse gradients')

                grads.append(grad)
                state = self.state[p]

                # State initialization: the previous gradient starts at zero and
                # every element's step size starts at the learning rate.
                if len(state) == 0:
                    state['step'] = 0
                    state['prev'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['step_size'] = grad.new().resize_as_(grad).fill_(group['lr'])

                prevs.append(state['prev'])
                step_sizes.append(state['step_size'])

                etaminus, etaplus = group['etas']
                step_size_min, step_size_max = group['step_sizes']

                state['step'] += 1

            # Dispatch the gathered tensors to the functional Rprop implementation.
            F.rprop(params,
                    grads,
                    prevs,
                    step_sizes,
                    step_size_min=step_size_min,
                    step_size_max=step_size_max,
                    etaminus=etaminus,
                    etaplus=etaplus)

        return loss
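A minimal usage sketch (not part of the source above; the model, data, and training loop are illustrative assumptions): construct the optimizer over a model's parameters, then alternate ``backward()`` and ``step()`` as with any ``torch.optim`` optimizer::

    import torch
    import torch.nn as nn

    # Hypothetical toy setup: fit a linear model to random data.
    model = nn.Linear(10, 1)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Rprop(model.parameters(), lr=1e-2,
                                  etas=(0.5, 1.2), step_sizes=(1e-6, 50))

    inputs = torch.randn(64, 10)
    targets = torch.randn(64, 1)

    for _ in range(100):
        optimizer.zero_grad()                       # clear stale gradients
        loss = criterion(model(inputs), targets)
        loss.backward()                             # populate p.grad for every parameter
        optimizer.step()                            # apply the Rprop update defined above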
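For clarity, the per-element update handed off to ``F.rprop`` can be restated roughly as below. This is a sketch of the pseudocode in the docstring, not the actual ``torch.optim._functional`` code; the function name and in-place update pattern are assumptions for illustration only::

    import torch

    def rprop_update(param, grad, prev, step_size,
                     step_size_min=1e-6, step_size_max=50.0,
                     etaminus=0.5, etaplus=1.2):
        # Sign of the product tells whether each gradient element kept
        # its direction since the previous step.
        change = (grad * prev).sign()
        factor = torch.where(change > 0,
                             torch.full_like(step_size, etaplus),
                             torch.where(change < 0,
                                         torch.full_like(step_size, etaminus),
                                         torch.ones_like(step_size)))
        # Grow or shrink each element's step size, clamped to the allowed range.
        step_size.mul_(factor).clamp_(step_size_min, step_size_max)
        # Move each coordinate by its own step size, against the gradient sign.
        param.add_(grad.sign() * step_size, alpha=-1)
        # Remember the current gradient for the next sign comparison.
        prev.copy_(grad)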