import contextlib
import logging
import tempfile
import warnings
from collections.abc import Mapping
from pathlib import Path

import torch
from torch.optim.lr_scheduler import _LRScheduler

from ignite.contrib.handlers.param_scheduler import LRScheduler, PiecewiseLinear
from ignite.engine import Events
from ignite.handlers import Checkpoint


class FastaiLRFinder:
    """Learning rate finder handler for supervised trainers.

    While attached, the handler increases the learning rate in between two
    boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of
    learning rates and what can be an optimal learning rate.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import FastaiLRFinder

        trainer = ...
        model = ...
        optimizer = ...

        lr_finder = FastaiLRFinder()
        to_save = {"model": model, "optimizer": optimizer}

        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
            trainer_with_lr_finder.run(dataloader)

        # Get lr_finder results
        lr_finder.get_results()

        # Plot lr_finder results (requires matplotlib)
        lr_finder.plot()

        # get lr_finder suggestion for lr
        lr_finder.lr_suggestion()

    Note:
        When the context manager is exited, all of the LR finder's handlers are removed.

    Note:
        Please also keep in mind that all other handlers attached to the trainer
        will be executed during the LR finder's run.

    Note:
        This class may require the `matplotlib` package to be installed to plot
        the learning rate range test:

        .. code-block:: bash

            pip install matplotlib

    References:

        Cyclical Learning Rates for Training Neural Networks:
        https://arxiv.org/abs/1506.01186

        fastai/lr_find: https://github.com/fastai/fastai
    """

    def __init__(self):
        self._diverge_flag = False
        self._history = None
        self._best_loss = None
        self._lr_schedule = None
        self.logger = logging.getLogger(__name__)

    def _run(self, trainer, optimizer, output_transform, num_iter, end_lr, step_mode, smooth_f, diverge_th):
        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
            if num_iter > max_iter:
                warnings.warn(
                    "Desired num_iter {} is unreachable with the current run setup of {} iteration "
                    "({} epochs)".format(num_iter, max_iter, trainer.state.max_epochs),
                    UserWarning,
                )

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, self._log_lr_and_loss, output_transform, smooth_f, diverge_th
            )

        self.logger.debug("Running LR finder for {} iterations".format(num_iter))

        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, end_lr, num_iter))
        else:
            start_lr = optimizer.param_groups[0]["lr"]
            self._lr_schedule = PiecewiseLinear(
                optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
            )
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)

    def _reset(self, trainer):
        self.logger.debug("Completed LR finder run")
        trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)
        trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED)
        trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED)

    def _log_lr_and_loss(self, trainer, output_transform, smooth_f, diverge_th):
        output = trainer.state.output
        loss = output_transform(output)
        lr = self._lr_schedule.get_param()
        self._history["lr"].append(lr)
        if trainer.state.iteration == 1:
            self._best_loss = loss
        else:
            if smooth_f > 0:
                # exponential smoothing: with smooth_f=0.05, the logged loss is
                # 0.05 * current loss + 0.95 * previous smoothed loss
                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
            if loss < self._best_loss:
                self._best_loss = loss
        self._history["loss"].append(loss)

        # Check if the loss has diverged; if it has, stop the trainer
        if self._history["loss"][-1] > diverge_th * self._best_loss:
            self._diverge_flag = True
            self.logger.info("Stopping early, the loss has diverged")
            trainer.terminate()

    def _reached_num_iterations(self, trainer, num_iter):
        if trainer.state.iteration > num_iter:
            trainer.terminate()

    def _warning(self, _):
        if not self._diverge_flag:
            warnings.warn(
                "Run completed without loss diverging, increase end_lr, decrease diverge_th or look"
                " at lr_finder.plot()",
                UserWarning,
            )

    def _detach(self, trainer):
        """
        Detaches lr_finder from trainer.

        Args:
            trainer: the trainer to detach from.
        """
        if trainer.has_event_handler(self._run, Events.STARTED):
            trainer.remove_event_handler(self._run, Events.STARTED)
        if trainer.has_event_handler(self._warning, Events.COMPLETED):
            trainer.remove_event_handler(self._warning, Events.COMPLETED)
        if trainer.has_event_handler(self._reset, Events.COMPLETED):
            trainer.remove_event_handler(self._reset, Events.COMPLETED)
    def get_results(self):
        """
        Returns:
            dictionary with loss and lr logs from the previous run
        """
        return self._history
    def plot(self, skip_start=10, skip_end=5, log_lr=True):
        """Plots the learning rate range test.

        This method requires the `matplotlib` package to be installed:

        .. code-block:: bash

            pip install matplotlib

        Args:
            skip_start (int, optional): number of batches to trim from the start. Default: 10.
            skip_end (int, optional): number of batches to trim from the end. Default: 5.
            log_lr (bool, optional): True to plot the learning rate in a logarithmic
                scale; otherwise, plotted in a linear scale. Default: True.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            raise RuntimeError(
                "This method requires matplotlib to be installed. "
                "Please install it with command: \n pip install matplotlib"
            )

        if self._history is None:
            raise RuntimeError("learning rate finder didn't run yet so results can't be plotted")

        if skip_start < 0:
            raise ValueError("skip_start cannot be negative")
        if skip_end < 0:
            raise ValueError("skip_end cannot be negative")

        # Get the data to plot from the history dictionary. Also, handle skip_end=0
        # properly so the behaviour is the expected one
        lrs = self._history["lr"]
        losses = self._history["loss"]
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]

        # Plot loss as a function of the learning rate
        plt.plot(lrs, losses)
        if log_lr:
            plt.xscale("log")
        plt.xlabel("Learning rate")
        plt.ylabel("Loss")
        plt.show()
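    # A small worked example of the trimming above (illustrative numbers, not part
    # of the original code): with 100 logged points, skip_start=10 and skip_end=5,
    # lrs[10:-5] keeps indices 10..94, i.e. 85 points. With skip_end=0 the slice
    # lrs[10:-0] evaluates to lrs[10:0], which is empty, hence the separate
    # lrs[skip_start:] branch.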
    def lr_suggestion(self):
        """
        Returns:
            learning rate at the minimum numerical gradient
        """
        if self._history is None:
            raise RuntimeError("learning rate finder didn't run yet so lr_suggestion can't be returned")
        loss = self._history["loss"]
        grads = torch.tensor([loss[i] - loss[i - 1] for i in range(1, len(loss))])
        min_grad_idx = grads.argmin() + 1
        return self._history["lr"][int(min_grad_idx)]
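    # A worked example of the suggestion rule above (illustrative values, not from
    # the original code): for losses [1.0, 0.8, 0.5, 0.6] the finite differences
    # are [-0.2, -0.3, 0.1]; the steepest drop (-0.3) sits at index 1, so
    # min_grad_idx = 2 and the suggested lr is self._history["lr"][2], i.e. the
    # learning rate at which the loss was falling fastest.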
    @contextlib.contextmanager
    def attach(
        self,
        trainer,
        to_save,
        output_transform=lambda output: output,
        num_iter=None,
        end_lr=10.0,
        step_mode="exp",
        smooth_f=0.05,
        diverge_th=5.0,
    ):
        """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

        Usage:

        .. code-block:: python

            to_save = {"model": model, "optimizer": optimizer}
            with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
                trainer_with_lr_finder.run(dataloader)

        Args:
            trainer (Engine): lr_finder is attached to this trainer. Please, keep in mind that
                all attached handlers will be executed.
            to_save (Mapping): dictionary with optimizer and other objects that need to be restored
                after running the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`.
                All objects should implement `state_dict` and `load_state_dict` methods.
            output_transform (callable, optional): function that transforms the trainer's `state.output`
                after each iteration. It must return the loss of that iteration.
            num_iter (int, optional): number of iterations for the lr schedule between base lr and
                `end_lr`. By default, it will run for `trainer.state.epoch_length * trainer.state.max_epochs`.
            end_lr (float, optional): upper bound for lr search. Default, 10.0.
            step_mode (str, optional): "exp" or "linear", which way should the lr be increased from
                the optimizer's initial lr to `end_lr`. Default, "exp".
            smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05.
            diverge_th (float, optional): used for stopping the search when
                `current loss > diverge_th * best_loss`. Default, 5.0.

        Note:
            lr_finder cannot be attached to more than one trainer at a time.

        Returns:
            trainer_with_lr_finder: trainer used for finding the lr
        """
        if not isinstance(to_save, Mapping):
            raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

        Checkpoint._check_objects(to_save, "state_dict")
        Checkpoint._check_objects(to_save, "load_state_dict")

        if "optimizer" not in to_save:
            raise ValueError("Mapping to_save should contain 'optimizer' key")

        if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
            raise ValueError(
                "Object to_save['optimizer'] should be torch optimizer, but given {}".format(
                    type(to_save["optimizer"])
                )
            )

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1)")
        if diverge_th < 1:
            raise ValueError("diverge_th should be larger than 1")
        if step_mode not in ["exp", "linear"]:
            raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
        if num_iter is not None and (not isinstance(num_iter, int) or num_iter <= 0):
            raise ValueError("if provided, num_iter should be a positive integer, but given {}".format(num_iter))

        # store to_save
        with tempfile.TemporaryDirectory() as tmpdirname:
            obj = {k: o.state_dict() for k, o in to_save.items()}
            # add trainer
            obj["trainer"] = trainer.state_dict()
            cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
            torch.save(obj, cache_filepath.as_posix())

            optimizer = to_save["optimizer"]
            # Attach handlers
            if not trainer.has_event_handler(self._run):
                trainer.add_event_handler(
                    Events.STARTED,
                    self._run,
                    optimizer,
                    output_transform,
                    num_iter,
                    end_lr,
                    step_mode,
                    smooth_f,
                    diverge_th,
                )
            if not trainer.has_event_handler(self._warning):
                trainer.add_event_handler(Events.COMPLETED, self._warning)
            if not trainer.has_event_handler(self._reset):
                trainer.add_event_handler(Events.COMPLETED, self._reset)

            yield trainer
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
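# Illustrative only, not part of the original module: a minimal end-to-end
# sketch of the intended workflow, assuming a standard ignite supervised setup
# where the caller supplies the model, optimizer, loss and dataloader.
def _example_lr_finder_usage(model, optimizer, criterion, dataloader, max_epochs=10):
    from ignite.engine import create_supervised_trainer

    trainer = create_supervised_trainer(model, optimizer, criterion)
    lr_finder = FastaiLRFinder()
    to_save = {"model": model, "optimizer": optimizer}

    # Run the range test; model, optimizer and trainer state are restored
    # when the context manager exits.
    with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
        trainer_with_lr_finder.run(dataloader)

    # Adopt the suggested lr for the real training run.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr_finder.lr_suggestion()
    trainer.run(dataloader, max_epochs=max_epochs)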
class _ExponentialLR(_LRScheduler):
    """Exponentially increases the learning rate between two boundaries over a number of
    iterations.

    Arguments:
        optimizer (torch.optim.Optimizer): wrapped optimizer.
        end_lr (float, optional): the final learning rate which is the upper boundary
            of the test. Default: 10.
        num_iter (int, optional): the number of iterations over which the test occurs.
            Default: 100.
        last_epoch (int): the index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        self.end_lr = end_lr
        self.num_iter = num_iter
        super(_ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
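# Illustrative only: a quick numeric check of the exponential schedule above.
# With base_lr=1e-5, end_lr=10 and num_iter=100, the lr after iteration i is
# 1e-5 * (10 / 1e-5) ** (i / 100): 1e-5 at i=0, 1e-2 at i=50 and 10 at i=100,
# so the lr sweeps six orders of magnitude geometrically.
if __name__ == "__main__":
    base_lr, end_lr, num_iter = 1e-5, 10.0, 100
    for i in (0, 50, 100):
        print(i, base_lr * (end_lr / base_lr) ** (i / num_iter))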