import torch
from functools import reduce
from .optimizer import Optimizer

__all__ = ['LBFGS']


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1 ** 2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals
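# --- Illustrative sketch (not part of the original module) -----------------
# The two private helpers above implement the strong Wolfe line search used
# by LBFGS below. Their contract is easiest to see on a tiny 1-D problem:
# `obj_func(x, t, d)` must return the loss as a float and the flat gradient
# as a tensor, both evaluated at the parameters `x` displaced by `t * d`.
# The helper name and the quadratic objective below are assumptions made up
# purely for illustration.
def _example_strong_wolfe():
    w0 = torch.tensor([0.0])

    def obj_func(x, t, d):
        w = x[0] + t * d                 # trial point along the search ray
        loss = ((w - 3.0) ** 2).sum()    # f(w) = (w - 3)^2, minimum at w = 3
        grad = 2.0 * (w - 3.0)           # analytic gradient
        return float(loss), grad

    f0, g0 = obj_func([w0], 0.0, torch.zeros_like(w0))
    d = -g0                              # steepest-descent direction
    gtd = g0.dot(d)                      # directional derivative (negative)
    f_new, g_new, t, n_evals = _strong_wolfe(obj_func, [w0], 1.0, d, f0, g0, gtd)
    # For this quadratic the cubic interpolation lands on the exact
    # minimizing step, t ~= 0.5 (so that w0 + t * d == 3).
    return t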
class LBFGS(Optimizer):
    """Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad
    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss
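# --- Illustrative sketch (not part of the original module) -----------------
# The class docstring above does not spell out the closure protocol, so this
# is a minimal usage sketch: the closure must zero the gradients, recompute
# the loss, call backward(), and return the loss, because `step` may invoke
# it many times per call (inner iterations and, optionally, the strong Wolfe
# line search). The model, data, and helper name are made up for illustration.
def _example_lbfgs_usage():
    model = torch.nn.Linear(2, 1)
    inputs = torch.randn(8, 2)
    targets = torch.randn(8, 1)
    optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=20,
                      line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        return loss

    for _ in range(5):                   # each call runs up to max_iter inner steps
        optimizer.step(closure)
    return float(closure())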
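# --- Illustrative sketch (not part of the original module) -----------------
# The core of `step` is the two-loop recursion that multiplies the negated
# gradient by the L-BFGS inverse-Hessian approximation built from the stored
# pairs (old_stps[i] = s_i, old_dirs[i] = y_i, ro[i] = 1 / y_i.s_i, oldest
# first). The sanity check below, with made-up sizes and helper name, mirrors
# that recursion and confirms it matches the explicitly constructed
# approximation H <- (I - rho*s*y^T) H (I - rho*y*s^T) + rho*s*s^T applied
# oldest pair first with H0 = H_diag * I.
def _check_two_loop_recursion():
    torch.manual_seed(0)
    n, m = 5, 3
    H_diag = 1.0
    old_stps = [torch.randn(n) for _ in range(m)]            # s_i, oldest first
    old_dirs = [s + 0.1 * torch.randn(n) for s in old_stps]  # y_i with y.s > 0
    ro = [1.0 / y.dot(s) for y, s in zip(old_dirs, old_stps)]
    flat_grad = torch.randn(n)

    # (a) two-loop recursion, mirroring the loop in `step`
    q = flat_grad.neg()
    al = [None] * m
    for i in range(m - 1, -1, -1):
        al[i] = ro[i] * old_stps[i].dot(q)
        q = q - al[i] * old_dirs[i]
    r = H_diag * q
    for i in range(m):
        be_i = ro[i] * old_dirs[i].dot(r)
        r = r + (al[i] - be_i) * old_stps[i]

    # (b) explicit inverse-Hessian approximation, oldest pair applied first
    eye = torch.eye(n)
    H = H_diag * eye
    for s, y, rho in zip(old_stps, old_dirs, ro):
        V = eye - rho * torch.outer(y, s)
        H = V.t() @ H @ V + rho * torch.outer(s, s)
    d_explicit = H @ flat_grad.neg()

    assert torch.allclose(r, d_explicit, atol=1e-5)
    return r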