CosineAnnealingWarmRestarts

class torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose='deprecated')

Set the learning rate of each parameter group using a cosine annealing schedule, where $\eta_{max}$ is set to the initial lr, $T_{cur}$ is the number of epochs since the last restart and $T_{i}$ is the number of epochs between two warm restarts in SGDR:

$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)$$

When $T_{cur}=T_{i}$, set $\eta_t = \eta_{min}$. When $T_{cur}=0$ after restart, set $\eta_t=\eta_{max}$.
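
As a rough illustration of the formula above, the following pure-Python sketch evaluates the learning rate for a given position in a cycle. The helper name `cosine_annealed_lr` and its argument names are hypothetical and not part of the PyTorch API; the scheduler's own implementation may differ in detail.

```python
import math

def cosine_annealed_lr(eta_max, eta_min, t_cur, t_i):
    """Evaluate the cosine annealing formula for one parameter group.

    Hypothetical helper for illustration; not a PyTorch function.
    """
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

# Learning rate over one cycle of length T_i = 10, starting from lr = 0.1
# and annealing towards eta_min = 0.
cycle = [cosine_annealed_lr(0.1, 0.0, t, 10) for t in range(11)]
```

The list starts at the initial lr, passes through the midpoint between the two bounds halfway through the cycle, and reaches eta_min at the end of the cycle.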

This schedule was proposed in SGDR: Stochastic Gradient Descent with Warm Restarts.

Parameters
  • optimizer (Optimizer) – Wrapped optimizer.

  • T_0 (int) – Number of iterations until the first restart.

  • T_mult (int, optional) – A factor by which $T_{i}$ increases after a restart. Default: 1.

  • eta_min (float, optional) – Minimum learning rate. Default: 0.

  • last_epoch (int, optional) – The index of the last epoch. Default: -1.

  • verbose (bool | str) –

    If True, prints a message to stdout for each update. Default: False.

    Deprecated since version 2.2: verbose is deprecated. Please use get_last_lr() to access the learning rate.
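
To make the effect of T_0 and T_mult more concrete, here is a small pure-Python sketch that finds, for a whole-number epoch, the position within the current cycle and the length of that cycle. It assumes each cycle is T_mult times longer than the one before it. `locate_epoch` is a hypothetical helper, not a PyTorch function.

```python
def locate_epoch(epoch, t_0, t_mult=1):
    """Return (t_cur, t_i) for a non-negative integer epoch.

    Hypothetical helper: walks through cycles of length
    t_0, t_0 * t_mult, t_0 * t_mult**2, ... until the epoch fits.
    """
    t_cur, t_i = epoch, t_0
    while t_cur >= t_i:
        t_cur -= t_i   # skip a completed cycle
        t_i *= t_mult  # the next cycle is t_mult times longer
    return t_cur, t_i

# Epochs (other than 0) at which a new cycle begins for t_0 = 10, t_mult = 2.
restarts = [e for e in range(1, 100) if locate_epoch(e, 10, 2)[0] == 0]
```

With T_mult=1 every cycle has the same length T_0; with T_mult=2 the gaps between restarts double each time.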

get_last_lr()

Return the last learning rate computed by the current scheduler.

Return type

List[float]

load_state_dict(state_dict)

Load the scheduler's state.

Parameters

state_dict (dict) – scheduler state. Should be an object returned from a call to state_dict().

print_lr(is_verbose, group, lr, epoch=None)

Display the current learning rate.

Deprecated since version 2.4: print_lr() is deprecated. Please use get_last_lr() to access the learning rate.

state_dict()

Returns the state of the scheduler as a dict.

It contains an entry for every variable in self.__dict__ which is not the optimizer.
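
A minimal sketch of checkpointing a scheduler with state_dict() and load_state_dict(), assuming PyTorch is installed. The model, the optimizer settings and the `make_scheduler` helper are placeholders chosen for illustration. The optimizer's state is saved and restored alongside the scheduler's here, since the scheduler's state excludes the optimizer.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

def make_scheduler():
    # Placeholder model and optimizer, used only for this illustration.
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return optimizer, CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

optimizer, scheduler = make_scheduler()
for _ in range(3):
    optimizer.step()  # no gradients were computed, so parameters are unchanged
    scheduler.step()

# Collect both states, as one would when writing a checkpoint.
checkpoint = {
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}

# Restore into freshly constructed objects.
new_optimizer, new_scheduler = make_scheduler()
new_optimizer.load_state_dict(checkpoint["optimizer"])
new_scheduler.load_state_dict(checkpoint["scheduler"])
```

After loading, `new_scheduler.get_last_lr()` should match the value reported by the original scheduler.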

step(epoch=None)

step() can be called after every batch update by passing a fractional epoch value.

Example

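>>> from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
>>> # optimizer, net, criterion, dataloader, T_0 and T_mult are assumed to be defined.
>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
>>> iters = len(dataloader)  # number of batches per epoch
>>> for epoch in range(20):
...     for i, sample in enumerate(dataloader):
...         inputs, labels = sample['inputs'], sample['labels']
...         optimizer.zero_grad()
...         outputs = net(inputs)
...         loss = criterion(outputs, labels)
...         loss.backward()
...         optimizer.step()
...         # Pass a fractional epoch so the learning rate is updated every batch.
...         scheduler.step(epoch + i / iters)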

Calls with and without an explicit epoch can be interleaved. A call without an argument continues from the epoch of the previous call.

Example

>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
>>> for epoch in range(20):
...     scheduler.step()
>>> scheduler.step(26)
>>> scheduler.step()  # equivalent to scheduler.step(27), not scheduler.step(20)
