LinearCyclicalScheduler

class ignite.handlers.param_scheduler.LinearCyclicalScheduler(*args, monotonic=False, **kwargs)

Linearly adjusts param value to ‘end_value’ for a half-cycle, then linearly adjusts it back to ‘start_value’ for a half-cycle.

Parameters
  • optimizer – torch optimizer or any object with attribute param_groups as a sequence.

  • param_name – name of optimizer’s parameter to update.

  • start_value – value at start of cycle.

  • end_value – value at the middle of the cycle.

  • cycle_size – length of cycle.

  • cycle_mult – ratio by which to change the cycle_size at the end of each cycle (default=1). A runnable sketch follows this parameter list.

  • start_value_mult – ratio by which to change the start value at the end of each cycle (default=1.0).

  • end_value_mult – ratio by which to change the end value at the end of each cycle (default=1.0).

  • warmup_duration – duration of the warm-up applied before each cycle. During this warm-up, the parameter starts from the last cycle’s end value and linearly goes to the next cycle’s start value. Default is no cyclic warm-up.

  • save_history – whether to log the parameter values to engine.state.param_history (default=False).

  • param_group_index – optimizer’s parameters group to use.

  • monotonic (bool) – whether to schedule only one half of the cycle: descending or ascending. If True, this argument cannot be used together with warmup_duration. (default=False).

  • args (Any) – positional arguments forwarded to the underlying cyclical scheduler (the parameters documented above).

  • kwargs (Any) – keyword arguments forwarded to the underlying cyclical scheduler (the parameters documented above).
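The multiplier arguments are easiest to see in action. Below is a minimal, self-contained sketch (the toy_* names are illustrative, not part of the API): per the parameter descriptions above, each successive cycle should be twice as long (cycle_mult=2) and peak at half the previous end value (end_value_mult=0.5), while save_history=True mirrors the scheduled values into engine.state.param_history.

import torch
from ignite.engine import Engine, Events
from ignite.handlers import LinearCyclicalScheduler

toy_param = torch.zeros([1], requires_grad=True)
toy_optimizer = torch.optim.SGD([toy_param], lr=0.1)
toy_trainer = Engine(lambda engine, batch: None)

# cycle 1: 4 iterations, peaking at 1.0; cycle 2: 8 iterations
# (cycle_mult=2), peaking at 0.5 (end_value_mult=0.5)
scheduler = LinearCyclicalScheduler(
    toy_optimizer, "lr", 0.0, 1.0, 4,
    cycle_mult=2, end_value_mult=0.5, save_history=True,
)
toy_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
toy_trainer.run([0] * 12, max_epochs=1)

# scheduled values are logged under the parameter name
print(toy_trainer.state.param_history["lr"])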

Note

If the scheduler is bound to an ‘ITERATION_*’ event, ‘cycle_size’ should usually be the number of batches in an epoch.
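For instance, to make one full cycle span exactly one epoch, a common pattern is to derive cycle_size from the loader length. A self-contained sketch (the toy data and the train_loader and trainer names are illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.handlers import LinearCyclicalScheduler

# toy data: 32 samples in batches of 8 -> 4 iterations per epoch
train_loader = DataLoader(TensorDataset(torch.zeros(32, 4)), batch_size=8)

param = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.1)
trainer = Engine(lambda engine, batch: None)

# one full cycle spans exactly one epoch
scheduler = LinearCyclicalScheduler(
    optimizer, "lr", 0.0, 0.1, cycle_size=len(train_loader)
)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run(train_loader, max_epochs=2)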

Examples

from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.metrics.regression import *
from ignite.utils import *

# create default evaluator for doctests

def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests

param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup:`

def get_default_trainer():

    def train_step(engine, batch):
        return batch

    return Engine(train_step)

# create default model for doctests

default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)
default_trainer = get_default_trainer()

# Linearly increases the learning rate from 0.0 to 1.0 and back to 0.0
# over a cycle of 4 iterations
scheduler = LinearCyclicalScheduler(default_optimizer, "lr", 0.0, 1.0, 4)

default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(default_optimizer.param_groups[0]["lr"])

default_trainer.run([0] * 9, max_epochs=1)
0.0
0.5
1.0
0.5
...

default_trainer = get_default_trainer()

optimizer = torch.optim.SGD(
    [
        {"params": default_model.base.parameters(), "lr": 0.001},
        {"params": default_model.fc.parameters(), "lr": 0.01},
    ]
)

# Linearly increases the learning rate from 0.0 to 1.0 and back to 0.0
# over a cycle of 4 iterations
scheduler1 = LinearCyclicalScheduler(optimizer, "lr (base)", 0.0, 1.0, 4, param_group_index=0)

# Linearly increases the learning rate from 0.0 to 0.1 and back to 0.0
# over a cycle of 4 iterations
scheduler2 = LinearCyclicalScheduler(optimizer, "lr (fc)", 0.0, 0.1, 4, param_group_index=1)

default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler1)
default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler2)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr (base)"],
          optimizer.param_groups[1]["lr (fc)"])

default_trainer.run([0] * 9, max_epochs=1)
0.0 0.0
0.5 0.05
1.0 0.1
0.5 0.05
...

New in version 0.4.5.

Changed in version 0.4.13: Added cyclic warm-up to the scheduler using warmup_duration.

Changed in version 0.5.0: Added monotonic argument.
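A minimal sketch of the monotonic option, reusing the doctest setup above. The direction (descending or ascending) is inferred here from the sign of end_value - start_value, so this configuration ascends; exact printed values are not asserted.

default_trainer = get_default_trainer()

# with monotonic=True only one half of the cycle is scheduled:
# the value climbs from 0.0 toward 1.0 over each 4-iteration cycle
# and restarts at 0.0, instead of ramping back down
scheduler = LinearCyclicalScheduler(default_optimizer, "lr", 0.0, 1.0, 4, monotonic=True)
default_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

@default_trainer.on(Events.ITERATION_COMPLETED)
def print_lr():
    print(default_optimizer.param_groups[0]["lr"])

default_trainer.run([0] * 9, max_epochs=1)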

Methods

get_param()

Method to get the current optimizer parameter value.

Return type

float
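A small usage sketch, reusing the doctest setup above; that the pre-run value equals the start-of-cycle value is an inference from the schedule description, not asserted by this page.

scheduler = LinearCyclicalScheduler(default_optimizer, "lr", 0.0, 1.0, 4)
# before the scheduler has processed any event, this should report
# the start-of-cycle value (0.0 for this configuration)
print(scheduler.get_param())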