"""Abstract base classes for espnet2 learning-rate schedulers (espnet2.schedulers.abs_scheduler)."""

from abc import ABC, abstractmethod
from typing import Optional

import torch.optim.lr_scheduler as L


class AbsScheduler(ABC):
    """Common interface for all schedulers used by espnet2 trainers.

    Subclasses (or classes registered via ``ABCMeta.register``) must provide
    ``step`` to advance the schedule and ``state_dict``/``load_state_dict``
    for checkpointing.
    """

    @abstractmethod
    def step(self, epoch: Optional[int] = None):
        """Advance the scheduler by one step.

        Args:
            epoch: Optional explicit epoch/step index; ``None`` means
                "advance by one" (matching the torch scheduler convention).
        """
        pass

    @abstractmethod
    def state_dict(self):
        """Return the scheduler state for checkpointing."""
        pass

    @abstractmethod
    def load_state_dict(self, state):
        """Restore the state previously produced by ``state_dict``."""
        pass
# If you need to define a custom scheduler, inherit from these classes.
class AbsBatchStepScheduler(AbsScheduler):
    """Scheduler interface whose ``step`` is called once per training batch."""

    @abstractmethod
    def step(self, epoch: Optional[int] = None):
        """Advance the scheduler by one batch step."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return the scheduler state for checkpointing."""
        pass

    @abstractmethod
    def load_state_dict(self, state):
        """Restore the state previously produced by ``state_dict``."""
        pass
class AbsEpochStepScheduler(AbsScheduler):
    """Scheduler interface whose ``step`` is called once per training epoch."""

    @abstractmethod
    def step(self, epoch: Optional[int] = None):
        """Advance the scheduler by one epoch step."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return the scheduler state for checkpointing."""
        pass

    @abstractmethod
    def load_state_dict(self, state):
        """Restore the state previously produced by ``state_dict``."""
        pass
class AbsValEpochStepScheduler(AbsEpochStepScheduler):
    """Epoch-level scheduler whose ``step`` additionally consumes a validation value.

    Mirrors the ``ReduceLROnPlateau`` calling convention, which is registered
    against this class elsewhere in this module.
    """

    @abstractmethod
    def step(self, val, epoch: Optional[int] = None):
        """Advance the scheduler using a validation metric.

        Args:
            val: Monitored validation value (e.g. a loss) driving the schedule.
            epoch: Optional explicit epoch index; ``None`` means "advance by one".
        """
        pass

    @abstractmethod
    def state_dict(self):
        """Return the scheduler state for checkpointing."""
        pass

    @abstractmethod
    def load_state_dict(self, state):
        """Restore the state previously produced by ``state_dict``."""
        pass
# Create alias types to check the type.
# Note(kamo): Currently PyTorch doesn't provide base classes to judge
# these scheduler categories, so we register the torch schedulers as
# virtual subclasses of our ABCs (via ABCMeta.register) instead.
AbsValEpochStepScheduler.register(L.ReduceLROnPlateau)
for s in [
    L.ReduceLROnPlateau,
    L.LambdaLR,
    L.StepLR,
    L.MultiStepLR,  # fixed: was listed twice in the original
    L.ExponentialLR,
    L.CosineAnnealingLR,
]:
    AbsEpochStepScheduler.register(s)

AbsBatchStepScheduler.register(L.CyclicLR)
for s in [
    L.OneCycleLR,
    L.CosineAnnealingWarmRestarts,
]:
    AbsBatchStepScheduler.register(s)