Source code for espnet2.schedulers.cosine_anneal_warmup_restart

# code from: https://github.com/katsura-jp/pytorch-cosine-annealing-with-
# warmup/blob/master/cosine_annealing_warmup/scheduler.py
# original paper: https://arxiv.org/pdf/1608.03983.pdf
# Similar to PyTorch's official CosineAnnealingWarmRestarts, but additionally
# features a linear warmup phase and scaling of the max learning rate at each
# restart
import math

import torch
from torch.optim.lr_scheduler import _LRScheduler

from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler


class CosineAnnealingWarmupRestarts(_LRScheduler, AbsBatchStepScheduler):
    """Cosine annealing scheduler with linear warmup and warm restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult (float): Cycle steps magnification. Default: 1.0.
        max_lr (float): First cycle's max learning rate. Default: 0.1.
        min_lr (float): Min learning rate. Default: 0.001.
        warmup_steps (int): Linear warmup step size. Default: 0.
        gamma (float): Decrease rate of max learning rate by cycle. Default: 1.0.
        last_epoch (int): The index of the last epoch. Default: -1.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 0.1,
        min_lr: float = 0.001,
        warmup_steps: int = 0,
        gamma: float = 1.0,
        last_epoch: int = -1,
    ):
        assert warmup_steps < first_cycle_steps

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult  # cycle steps magnification
        self.base_max_lr = max_lr  # first max learning rate
        self.max_lr = max_lr  # max learning rate in the current cycle
        self.min_lr = min_lr  # min learning rate
        self.warmup_steps = warmup_steps  # warmup step size
        self.gamma = gamma  # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle = 0  # cycle count
        self.step_in_cycle = last_epoch  # step size of the current cycle

        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)

        # set learning rate to min_lr
        self.init_lr()
    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.min_lr
            self.base_lrs.append(self.min_lr)
    def get_lr(self):
        if self.step_in_cycle == -1:
            # Before the first step: return the base (min) learning rates
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr up to max_lr
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps
                + base_lr
                for base_lr in self.base_lrs
            ]
        else:
            # Cosine annealing from max_lr back down to base_lr over the
            # remainder of the current cycle
            return [
                base_lr
                + (self.max_lr - base_lr)
                * (
                    1
                    + math.cos(
                        math.pi
                        * (self.step_in_cycle - self.warmup_steps)
                        / (self.cur_cycle_steps - self.warmup_steps)
                    )
                )
                / 2
                for base_lr in self.base_lrs
            ]
    def step(self, epoch=None):
        if epoch is None:
            # Implicit step: advance one step within the current cycle
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Cycle finished: start a new (possibly longer) cycle
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        else:
            # Explicit epoch: recompute which cycle this step falls into
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.0:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    n = int(
                        math.log(
                            (
                                epoch / self.first_cycle_steps * (self.cycle_mult - 1)
                                + 1
                            ),
                            self.cycle_mult,
                        )
                    )
                    self.cycle = n
                    self.step_in_cycle = epoch - int(
                        self.first_cycle_steps
                        * (self.cycle_mult**n - 1)
                        / (self.cycle_mult - 1)
                    )
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult**n
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # Shrink the peak learning rate by gamma after every completed cycle,
        # then apply the new learning rates to the optimizer
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
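
# A minimal usage sketch. The model, optimizer settings, and hyperparameter
# values below are made-up placeholders, not part of the ESPnet source; they
# only illustrate how the scheduler is constructed and stepped. Since the class
# inherits AbsBatchStepScheduler, step() is assumed to be called once per batch.
#
#     import torch
#
#     model = torch.nn.Linear(10, 2)  # placeholder model
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#
#     scheduler = CosineAnnealingWarmupRestarts(
#         optimizer,
#         first_cycle_steps=1000,  # length of the first cycle, in steps
#         cycle_mult=2.0,          # each new cycle is twice as long
#         max_lr=0.1,              # peak LR of the first cycle
#         min_lr=0.001,            # floor LR
#         warmup_steps=100,        # linear warmup at the start of each cycle
#         gamma=0.5,               # halve the peak LR after every restart
#     )
#
#     for _ in range(3000):
#         optimizer.step()   # training step (loss/backward omitted here)
#         scheduler.step()   # advance the schedule once per batch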