"""ReduceLROnPlateau (with warm-up) learning rate scheduler module."""
from typing import Union
import torch
from torch import inf
from typeguard import check_argument_types
from espnet2.schedulers.abs_scheduler import (
AbsBatchStepScheduler,
AbsValEpochStepScheduler,
)
class WarmupReduceLROnPlateau(AbsBatchStepScheduler, AbsValEpochStepScheduler):
    """The WarmupReduceLROnPlateau scheduler.

    This scheduler is the combination of WarmupLR and ReduceLROnPlateau.

    WarmupLR::

        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    WarmupReduceLROnPlateau::

        if step <= warmup_step:
            # for step <= warmup_step this reduces to a linear ramp:
            # optimizer.lr * step / warmup_step
            lr = optimizer.lr * warmup_step ** 0.5
                 * min(step ** -0.5, step * warmup_step ** -1.5)
        else:
            lr = (
                max(lr * factor, min_lr)
                if no improvement for more than `patience` epochs
                else lr
            )

    Calling ``step()`` without ``metrics`` advances the warm-up by one step.
    Calling ``step(metrics)`` performs one ReduceLROnPlateau update. The two
    are presumably driven per batch and per validation epoch respectively,
    given the two base classes; confirm against the trainer.

    Note that the maximum lr equals ``optimizer.lr`` in this scheduler. More
    precisely, it is each param group's ``initial_lr``, which is set to its
    current ``lr`` at construction if absent. A non-positive
    ``warmup_steps`` disables the warm-up.

    Args:
        optimizer: Wrapped optimizer whose param groups' ``lr`` is updated.
        warmup_steps: Number of warm-up steps of the linear ramp.
        mode: ``"min"`` or ``"max"``; whether a lower or higher metric is
            an improvement.
        factor: Multiplicative lr decay applied on a plateau; must be < 1.0.
        patience: Number of non-improving epochs tolerated; the lr is
            decayed once ``num_bad_epochs > patience``.
        threshold: Minimum improvement over the best metric, relative or
            absolute depending on ``threshold_mode``.
        threshold_mode: ``"rel"`` or ``"abs"``.
        cooldown: Number of epochs after a decay during which bad epochs
            are not counted.
        min_lr: Lower bound on the lr; a scalar, or a list/tuple with one
            value per param group.
        eps: A decay is skipped when it would change the lr by <= ``eps``.
        verbose: Print a message each time a param group's lr is decayed.

    Raises:
        ValueError: If ``factor >= 1.0``, ``min_lr`` has the wrong length,
            or ``mode`` / ``threshold_mode`` is unknown.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        # for WarmupLR
        warmup_steps: Union[int, float] = 25000,
        # for ReduceLROnPlateau
        mode="min",
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=0,
        eps=1e-8,
        verbose=False,
    ):
        assert check_argument_types()

        # Validate every argument before touching the optimizer's param
        # groups, so that a bad argument leaves the optimizer unmodified.
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        num_groups = len(optimizer.param_groups)
        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != num_groups:
                raise ValueError(
                    "expected {} min_lrs, got {}".format(num_groups, len(min_lr))
                )
            min_lrs = list(min_lr)
        else:
            min_lrs = [min_lr] * num_groups
        # Sets self.mode, self.threshold, self.threshold_mode and
        # self.mode_worse; raises ValueError for unknown modes.
        self._init_is_better(
            mode=mode, threshold=threshold, threshold_mode=threshold_mode
        )

        # Attach optimizer
        self.optimizer = optimizer

        # WarmupLR state
        self.warmup_steps = warmup_steps
        self.step_num = 0
        # `warmup_steps ** -1` raises ZeroDivisionError for 0. A non-positive
        # value disables warm-up: step_num starts at 1, so the warm-up
        # branch in `step` is never taken and lr_scale is never read.
        self.lr_scale = warmup_steps**-1 if warmup_steps > 0 else 0.0

        # Initialize base learning rates (the warm-up target / peak lr)
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
        # NOTE(review): the constructor does not modify group["lr"], so any
        # optimizer update made before the first `step()` call uses the full
        # base lr. Confirm this is intended.

        # ReduceLROnPlateau state
        self.factor = factor
        self.min_lrs = min_lrs
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.eps = eps
        self.last_epoch = 0
        # Sets self.best, self.cooldown_counter and self.num_bad_epochs.
        self._reset()

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, "
            f"mode={self.mode}, factor={self.factor}, patience={self.patience})"
        )

    def step(self, metrics=None, epoch=None):
        """Advance the schedule by one step.

        Args:
            metrics: Metric value for the ReduceLROnPlateau update (a float
                or anything ``float()`` accepts). When ``None``, the call is
                a warm-up step instead.
            epoch: Optional epoch index passed to the plateau update; it
                defaults to ``last_epoch + 1``. It is ignored for warm-up
                steps.
        """
        if metrics is None:
            # WarmupLR: linear ramp from base_lr / warmup_steps to base_lr.
            # After warm-up, this branch leaves the lr untouched.
            self.step_num += 1
            if self.step_num <= self.warmup_steps:
                groups_and_lrs = zip(self.optimizer.param_groups, self.base_lrs)
                for param_group, lr in groups_and_lrs:
                    param_group["lr"] = lr * self.lr_scale * self.step_num
        else:
            # ReduceLROnPlateau
            self._step_reducelronplateau(metrics, epoch=epoch)

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def _step_reducelronplateau(self, metrics=None, epoch=None):
        """Run one ReduceLROnPlateau update with the given metric.

        The update also runs while warm-up is still in progress. However,
        any decay it applies is overwritten by the next warm-up step, which
        recomputes the lr from ``base_lrs``.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor``, clamped to its ``min_lr``.

        A group is left unchanged when the decay would be <= ``eps``.
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group["lr"] = new_lr
                if self.verbose:
                    epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                    print(
                        "Epoch {}: reducing learning rate"
                        " of group {} to {:.4e}.".format(epoch_str, i, new_lr)
                    )

    @property
    def in_cooldown(self):
        """Whether the scheduler is still in the post-decay cooldown period."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Return True if metric ``a`` improves on ``best``.

        The comparison uses ``mode`` and ``threshold_mode``.
        """
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon
        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate and store the comparison settings used by ``is_better``.

        Raises:
            ValueError: If ``mode`` or ``threshold_mode`` is unknown.
        """
        if mode not in {"min", "max"}:
            # f-string so that a non-str `mode` still yields ValueError
            raise ValueError(f"mode {mode} is unknown!")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError(f"threshold mode {threshold_mode} is unknown!")

        # Starting "best" value: anything beats it on the first update.
        if mode == "min":
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        """Return all scheduler attributes except the wrapped optimizer.

        The returned dict is shallow: list values such as ``base_lrs`` are
        shared with the scheduler, not copied.
        """
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and re-validate the modes."""
        self.__dict__.update(state_dict)
        self._init_is_better(
            mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
        )