Shortcuts

Source code for torch.optim.lbfgs

# mypy: allow-untyped-defs
from typing import Optional

import torch
from .optimizer import Optimizer, ParamsT

__all__ = ["LBFGS"]


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.0


def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    """Search along ``d`` for a step length satisfying the strong Wolfe conditions.

    The search has two phases that share one iteration budget (``max_ls``):
    a bracketing phase that grows the step until an interval containing an
    acceptable point is found, and a zoom phase that shrinks that interval
    using cubic interpolation.

    Args:
        obj_func: callable invoked as ``obj_func(x, t, d)``; must return a pair
            ``(f_new, g_new)`` where ``g_new`` supports ``.dot(d)`` and
            ``.clone(...)``.
        x: passed through unchanged to ``obj_func``.
        t: initial trial step length.
        d: search direction (a tensor; ``d.abs().max()`` and ``.dot`` are used).
        f: objective value at step length 0.
        g: gradient at step length 0 (cloned on entry).
        gtd: directional derivative at step length 0.
        c1: constant of the sufficient-decrease test ``f_new > f + c1*t*gtd``.
        c2: constant of the curvature test ``abs(gtd_new) <= -c2*gtd``.
        tolerance_change: the zoom phase stops once the bracket width times
            the largest absolute entry of ``d`` falls below this value.
        max_ls: maximum number of line-search iterations over both phases.

    Returns:
        Tuple ``(f_new, g_new, t, ls_func_evals)``: objective value, gradient
        and step length taken from the low end of the final bracket, and the
        number of ``obj_func`` calls made.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    # largest absolute component of the direction; scales the bracket-width test
    d_norm = d.abs().max()
    # work on a contiguous copy of the initial gradient
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    # the "previous" point starts out as step length 0 with the caller's values
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # sufficient decrease violated, or (after the first two iterations)
        # the objective did not drop below the previous trial: bracket found
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # curvature condition holds as well: accept `t` directly.
        # The bracket degenerates to a single point and `done` skips the zoom phase.
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # directional derivative is no longer negative: bracket found
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        # the next trial is restricted to [t + 0.01 * (t - t_prev), 10 * t]
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    # `bracket_gtd` is not assigned on this path; that is fine because the zoom
    # loop below also requires `ls_iter < max_ls` and therefore does not run.
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    # index -1 (rather than 1) keeps this valid for the one-element bracket
    # produced when the Wolfe point was found during bracketing.
    # NOTE(review): if bracket_f[0] were NaN the comparison is False and
    # low_pos becomes 1, which would index past a one-element bracket at the
    # return below - confirm NaN objective values cannot reach this point.
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundaries,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            # the new point replaces the high end; then re-derive which end is low
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old low becomes new high
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            # (this also runs when `done` was just set, so the accepted point
            # is what gets returned below)
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    # report the low end of the final bracket
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        params (iterable): iterable of parameters to optimize. Parameters must be real.
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: ``max_iter * 5 // 4``, i.e. max_iter * 1.25 rounded down).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1,
        max_iter: int = 20,
        max_eval: Optional[int] = None,
        tolerance_grad: float = 1e-7,
        tolerance_change: float = 1e-9,
        history_size: int = 100,
        line_search_fn: Optional[str] = None,
    ):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError(
                "LBFGS doesn't support per-parameter options (parameter groups)"
            )

        # the single group's parameter list, used by all flat-vector helpers
        self._params = self.param_groups[0]["params"]
        # lazily computed by _numel()
        self._numel_cache = None

    def _numel(self):
        """Return the length of the flat parameter vector (cached).

        Complex parameters count twice because they are handled through
        ``torch.view_as_real``.
        """
        if self._numel_cache is None:
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params
            )

        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate all parameter gradients into one 1-D tensor.

        Missing gradients contribute zeros, sparse gradients are densified and
        complex gradients are viewed as pairs of reals.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            if torch.is_complex(view):
                view = torch.view_as_real(view).view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` in place to the parameters.

        ``update`` is a flat vector laid out like the result of
        ``_gather_flat_grad``; each parameter consumes its own slice.
        """
        offset = 0
        for p in self._params:
            if torch.is_complex(p):
                p = torch.view_as_real(p)
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel
        # every entry of the flat vector must have been consumed
        assert offset == self._numel()

    def _clone_param(self):
        """Return contiguous copies of all parameters."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Copy ``params_data`` element-wise back into the parameters."""
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        """Evaluate loss and flat gradient at ``x + t * d``.

        The parameters are moved by ``t * d``, the closure is evaluated, and the
        parameters are then reset to ``x`` before returning.
        """
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by the first closure evaluation of this call.

        Raises:
            RuntimeError: if ``line_search_fn`` is set to anything other than
                ``"strong_wolfe"``.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group["lr"]
        max_iter = group["max_iter"]
        max_eval = group["max_eval"]
        tolerance_grad = group["tolerance_grad"]
        tolerance_change = group["tolerance_change"]
        line_search_fn = group["line_search_fn"]
        history_size = group["history_size"]

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault("func_evals", 0)
        state.setdefault("n_iter", 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state["func_evals"] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get("d")
        t = state.get("t")
        old_dirs = state.get("old_dirs")
        old_stps = state.get("old_stps")
        ro = state.get("ro")
        H_diag = state.get("H_diag")
        prev_flat_grad = state.get("prev_flat_grad")
        prev_loss = state.get("prev_loss")

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                # very first iteration ever: plain steepest descent, empty history
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            # remember gradient and loss of this iterate for the next update
            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state["n_iter"] == 1:
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")

                x_init = self._clone_param()

                def obj_func(x, t, d):
                    return self._directional_evaluate(closure, x, t, d)

                loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                    obj_func, x_init, t, d, loss, flat_grad, gtd
                )
                # the line search leaves the parameters at x_init; take the step
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist the search state so the next step() call continues from here
        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources