from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

# https://github.com/pytorch/ignite/issues/2773
@@ -792,6 +792,61 @@ def simulate_values( # type: ignore[override]
        return output


+class _CosineAnnealingWarmRestarts:
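+    """Wrapper for CosineAnnealingWarmRestarts that folds the epoch bookkeeping of ``step``
+    into ``get_lr`` so that ``LRScheduler`` can read the annealed learning rates directly."""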
+    def __init__(self, lr_scheduler: CosineAnnealingWarmRestarts):
+        self._lr_scheduler = lr_scheduler
+
+    @property
+    def last_epoch(self) -> int:
+        return self._lr_scheduler.last_epoch
+
+    @last_epoch.setter
+    def last_epoch(self, value: int) -> None:
+        self._lr_scheduler.last_epoch = value
+
+    @property
+    def optimizer(self) -> torch.optim.Optimizer:
+        return self._lr_scheduler.optimizer
+
+    def get_lr(self, epoch: Optional[int] = None) -> List[float]:
+        # TODO: Remove this workaround when pytorch has fixed wrong type hints:
+        # https://github.com/pytorch/pytorch/pull/102067
+        # Replace below T_mult -> self._lr_scheduler.T_mult
+        # Replace below eta_min -> self._lr_scheduler.eta_min
+        T_mult = cast(int, self._lr_scheduler.T_mult)
+        eta_min = cast(float, self._lr_scheduler.eta_min)
+
+        if epoch is None and self.last_epoch < 0:
+            epoch = 0
+        if epoch is None:
+            epoch = self.last_epoch + 1
+            self._lr_scheduler.T_cur = self._lr_scheduler.T_cur + 1
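+            # if the current cycle has finished, wrap T_cur around and scale the next cycle length by T_mult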
+            if self._lr_scheduler.T_cur >= self._lr_scheduler.T_i:
+                self._lr_scheduler.T_cur = self._lr_scheduler.T_cur - self._lr_scheduler.T_i
+                self._lr_scheduler.T_i = self._lr_scheduler.T_i * T_mult
+        else:
+            if epoch < 0:
+                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
+            if epoch >= self._lr_scheduler.T_0:
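+                # locate the restart cycle that contains `epoch` (cycle n has length T_0 * T_mult**n)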
+                if T_mult == 1:
+                    self._lr_scheduler.T_cur = epoch % self._lr_scheduler.T_0
+                else:
+                    n = int(math.log((epoch / self._lr_scheduler.T_0 * (T_mult - 1) + 1), T_mult))
+                    self._lr_scheduler.T_cur = epoch - self._lr_scheduler.T_0 * (T_mult**n - 1) / (T_mult - 1)
+                    self._lr_scheduler.T_i = self._lr_scheduler.T_0 * T_mult**n
+            else:
+                self._lr_scheduler.T_i = self._lr_scheduler.T_0
+                self._lr_scheduler.T_cur = epoch
+
+        self.last_epoch = math.floor(epoch)
+
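+        # anneal each base lr towards eta_min over the current cycle of length T_i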
+        return [
+            eta_min
+            + (base_lr - eta_min) * (1 + math.cos(math.pi * self._lr_scheduler.T_cur / self._lr_scheduler.T_i)) / 2
+            for base_lr in self._lr_scheduler.base_lrs
+        ]
+
+
class LRScheduler(ParamScheduler):
    """A wrapper class to call `torch.optim.lr_scheduler` objects as `ignite` handlers.

@@ -853,7 +908,10 @@ def __init__(
853908 f"but given { type (lr_scheduler )} "
854909 )
855910
856- self .lr_scheduler = lr_scheduler
911+ self .lr_scheduler : Union [PyTorchLRScheduler , _CosineAnnealingWarmRestarts ] = lr_scheduler
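+        # CosineAnnealingWarmRestarts needs the patched get_lr above, so wrap it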
+        if isinstance(lr_scheduler, CosineAnnealingWarmRestarts):
+            self.lr_scheduler = _CosineAnnealingWarmRestarts(lr_scheduler)
+
        super(LRScheduler, self).__init__(
            optimizer=self.lr_scheduler.optimizer,
            param_name="lr",
@@ -863,7 +921,7 @@ def __init__(
            warnings.warn(
                "Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
                "instead of Events.ITERATION_STARTED to make sure to use "
-                "the first lr value from the optimizer, otherwise it is will be skipped"
+                "the first lr value from the optimizer, otherwise it will be skipped"
            )
            self.lr_scheduler.last_epoch += 1

@@ -876,9 +934,9 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
    def get_param(self) -> Union[float, List[float]]:
        """Method to get current optimizer's parameter value"""
        # Emulate context manager for pytorch>=1.4
-        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[union-attr]
        lr_list = cast(List[float], self.lr_scheduler.get_lr())
-        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[union-attr]
        if len(lr_list) == 1:
            return lr_list[0]
        else:
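
With this change a torch CosineAnnealingWarmRestarts instance can be passed straight to the LRScheduler handler. A minimal usage sketch follows; the linear model, SGD optimizer, T_0/T_mult values, and the no-op update function are illustrative assumptions and not part of this commit:

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.engine import Engine, Events
from ignite.handlers import LRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# assumed cycle settings: first restart after 10 iterations, later cycles twice as long
torch_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

scheduler = LRScheduler(torch_scheduler)

trainer = Engine(lambda engine, batch: None)  # placeholder no-op update function
# attach on ITERATION_COMPLETED so the first lr value from the optimizer is not skipped
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)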