LinearCyclicalScheduler
class ignite.handlers.param_scheduler.LinearCyclicalScheduler(optimizer, param_name, start_value, end_value, cycle_size, cycle_mult=1.0, start_value_mult=1.0, end_value_mult=1.0, warmup_duration=0, save_history=False, param_group_index=None)
Linearly adjusts the parameter value from ‘start_value’ to ‘end_value’ over the first half of each cycle, then linearly back to ‘start_value’ over the second half.
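To make the triangular shape concrete, here is a minimal sketch that reproduces the values printed in the first example. It assumes all multipliers are 1.0 and there is no warm-up, and it takes the event index modulo the cycle size; the function name `linear_cyclical_value` is hypothetical and the code is not taken from the library.

```python
def linear_cyclical_value(event_index: int, start_value: float, end_value: float, cycle_size: int) -> float:
    """Hypothetical helper: triangular schedule start_value -> end_value -> start_value."""
    # Position within the current cycle, in [0, 1).
    cycle_progress = (event_index % cycle_size) / cycle_size
    # abs(progress - 0.5) * 2 is 1.0 at the cycle edges and 0.0 at the midpoint,
    # so the result is start_value at the edges and end_value in the middle.
    return end_value + (start_value - end_value) * abs(cycle_progress - 0.5) * 2


values = [linear_cyclical_value(i, 0.0, 1.0, 4) for i in range(9)]
print(values)  # [0.0, 0.5, 1.0, 0.5, 0.0, 0.5, 1.0, 0.5, 0.0]
```

With cycle_size=4, the peak is reached at index 2 and the value is back at ‘start_value’ at index 4, where the next cycle begins.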
Parameters
optimizer (Optimizer) – torch optimizer or any object with attribute param_groups as a sequence.
param_name (str) – name of optimizer’s parameter to update.
start_value (float) – value at the start of the cycle.
end_value (float) – value at the middle of the cycle.
cycle_size (int) – length of the cycle, counted in scheduler events (e.g. iterations when bound to an ‘ITERATION_*’ event).
cycle_mult (float) – ratio by which to change the cycle_size at the end of each cycle (default=1.0).
start_value_mult (float) – ratio by which to change the start value at the end of each cycle (default=1.0).
end_value_mult (float) – ratio by which to change the end value at the end of each cycle (default=1.0).
warmup_duration (int) – duration of the warm-up applied before each cycle. During this warm-up, the parameter starts from the previous cycle’s end value and goes linearly to the next cycle’s start value. Default is 0 (no cyclic warm-up).
save_history (bool) – whether to log the parameter values to engine.state.param_history (default=False).
param_group_index (Optional[int]) – index of the optimizer’s parameter group to update (default=None, which updates all parameter groups).
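How the cycle length and endpoints evolve from one cycle to the next can be sketched as follows. This is an illustration built from the parameter descriptions above, not ignite's internals: `cycle_parameters` and `warmup_ramp` are hypothetical helpers, and both the truncation of the rescaled cycle length to an int and the exact endpoints of the warm-up ramp are assumptions.

```python
from typing import Iterator, List, Tuple


def cycle_parameters(
    start_value: float,
    end_value: float,
    cycle_size: int,
    num_cycles: int,
    cycle_mult: float = 1.0,
    start_value_mult: float = 1.0,
    end_value_mult: float = 1.0,
) -> Iterator[Tuple[float, float, int]]:
    """Hypothetical helper: yield (start_value, end_value, cycle_size) for each cycle."""
    for _ in range(num_cycles):
        yield start_value, end_value, cycle_size
        # At the end of each cycle, rescale the cycle length and both endpoints.
        # Assumption: the rescaled cycle length is truncated to an int.
        cycle_size = int(cycle_size * cycle_mult)
        start_value *= start_value_mult
        end_value *= end_value_mult


def warmup_ramp(from_value: float, to_value: float, warmup_duration: int) -> List[float]:
    """Hypothetical helper: linear ramp from from_value toward to_value."""
    # Assumption: the ramp emits from_value first and stops one step short of
    # to_value, so the next cycle can begin exactly at to_value.
    # warmup_duration=0 yields an empty list, i.e. no warm-up.
    return [
        from_value + (to_value - from_value) * i / warmup_duration
        for i in range(warmup_duration)
    ]


print(list(cycle_parameters(0.0, 1.0, 4, 3, cycle_mult=2.0, end_value_mult=0.5)))
# [(0.0, 1.0, 4), (0.0, 0.5, 8), (0.0, 0.25, 16)]
print(warmup_ramp(0.0, 1.0, 4))  # [0.0, 0.25, 0.5, 0.75]
```

With cycle_mult=2.0 and end_value_mult=0.5, each cycle is twice as long as the previous one and peaks at half its end value.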
Note
If the scheduler is bound to an ‘ITERATION_*’ event, ‘cycle_size’ should usually be the number of batches in an epoch.
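For a dataset loaded in fixed-size batches, the number of batches per epoch is the dataset size divided by the batch size, rounded up, or rounded down when the incomplete last batch is dropped. The helper below is a hypothetical sketch of that arithmetic; with a PyTorch DataLoader over a map-style dataset, `len(train_loader)` normally gives the same number directly.

```python
def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of ITERATION_* events in one epoch."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept; integer ceiling division avoids float rounding.
    return (num_samples + batch_size - 1) // batch_size


cycle_size = batches_per_epoch(1000, 32)
print(cycle_size)  # 32
print(batches_per_epoch(1000, 32, drop_last=True))  # 31
```

With this cycle_size, constant multipliers and no warm-up, every epoch covers exactly one cycle, so the parameter is back at start_value at the start of each epoch.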
Examples
Default setup shared by the examples below:

```python
from collections import OrderedDict

import torch
from torch import nn

from ignite.engine import Engine, Events
from ignite.handlers import LinearCyclicalScheduler
from ignite.utils import manual_seed

# Create the default evaluator for doctests.
def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# Create the default optimizer for doctests.
param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# Create the default trainer for doctests.
# As handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup::`
def get_default_trainer():
    def train_step(engine, batch):
        return batch
    return Engine(train_step)

# Create the default model for doctests.
default_model = nn.Sequential(OrderedDict([
    ("base", nn.Linear(4, 2)),
    ("fc", nn.Linear(2, 1)),
]))

manual_seed(666)
```
```python
default_trainer = get_default_trainer()

# Linearly increases the learning rate from 0.0 to 1.0 and back to 0.0
# over a cycle of 4 iterations
scheduler = LinearCyclicalScheduler(default_optimizer, "lr", 0.0, 1.0, 4)

default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(default_optimizer.param_groups[0]["lr"])

default_trainer.run([0] * 9, max_epochs=1)
```

```text
0.0
0.5
1.0
0.5
...
```
```python
default_trainer = get_default_trainer()

optimizer = torch.optim.SGD(
    [
        {"params": default_model.base.parameters(), "lr": 0.001},
        {"params": default_model.fc.parameters(), "lr": 0.01},
    ]
)

# Linearly increases the learning rate of group 0 (base) from 0.0 to 1.0
# and back to 0.0 over a cycle of 4 iterations
scheduler1 = LinearCyclicalScheduler(optimizer, "lr", 0.0, 1.0, 4, param_group_index=0)

# Linearly increases the learning rate of group 1 (fc) from 0.0 to 0.1
# and back to 0.0 over a cycle of 4 iterations
scheduler2 = LinearCyclicalScheduler(optimizer, "lr", 0.0, 0.1, 4, param_group_index=1)

default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)
default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"])

default_trainer.run([0] * 9, max_epochs=1)
```

```text
0.0 0.0
0.5 0.05
1.0 0.1
0.5 0.05
...
```
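The second example relies on param_group_index to give each parameter group its own schedule. The group selection can be sketched with plain dictionaries standing in for optimizer.param_groups; `set_param_value` is a hypothetical helper, and treating `param_group_index=None` as "update every group" is an assumption of this sketch.

```python
from typing import Any, Dict, List, Optional


def set_param_value(
    param_groups: List[Dict[str, Any]],
    param_name: str,
    value: float,
    param_group_index: Optional[int] = None,
) -> None:
    """Hypothetical helper: write value into one group, or into all groups if index is None."""
    groups = param_groups if param_group_index is None else [param_groups[param_group_index]]
    for group in groups:
        group[param_name] = value


groups = [{"lr": 0.001}, {"lr": 0.01}]
set_param_value(groups, "lr", 0.5, param_group_index=0)
print(groups)  # [{'lr': 0.5}, {'lr': 0.01}]
set_param_value(groups, "lr", 0.2)
print(groups)  # [{'lr': 0.2}, {'lr': 0.2}]
```

In this sketch, a param_name that the optimizer never reads would simply add a new key to the group dictionary and leave the real learning rate unchanged, so it should name an existing hyperparameter such as "lr".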
New in version 0.4.5.
Changed in version 0.4.13: Added cyclic warm-up to the scheduler using warmup_duration.
Methods
Method to get the current value of the optimizer's parameter.