
Commit e9e5b45

Fixed parameter scheduler bug with CosineAnnealingWarmRestarts (#2938)
* remove codecov
* RankProcessFirst
* annotations
* from class to contextlib
* from class to contextlib and test
* del test file
* uniq folder for test
* refactor tests + new assert_test
* add to __all__, remove idist import
* Apply suggestions from code review
* Apply suggestions from code review
* Update tests/ignite/distributed/utils/test_native.py
* Added local arg and renamed function
* add proxy class
* annotation
* test, proxy class
* add optim
* name change
* test upd/ setter
* class fix
* Fixed mypy issues
* test upd
* Fixed failing test_lr_scheduler

---------

Co-authored-by: vfdev <[email protected]>
1 parent a99ea7f commit e9e5b45
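For context, a minimal usage sketch of the fixed wrapper (the optimizer, tensor, and step counts below are illustrative and not taken from the commit). Ignite's LRScheduler derives the learning rate from the scheduler's get_lr() rather than from step(), so the new _CosineAnnealingWarmRestarts proxy folds torch's restart bookkeeping (T_cur, T_i) into get_lr(); with this fix, CosineAnnealingWarmRestarts can be wrapped like any other torch scheduler:

# Sketch (assumed usage, not part of the commit): wrap torch's
# CosineAnnealingWarmRestarts in ignite's LRScheduler and drive it manually
# to inspect the learning rates it produces.
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers.param_scheduler import LRScheduler

param = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.2)

torch_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=1)
scheduler = LRScheduler(torch_scheduler)  # internally wrapped by _CosineAnnealingWarmRestarts

lrs = []
for _ in range(24):
    scheduler(None)  # the engine argument is optional; passing None mirrors how the tests drive schedulers
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)  # cosine decay over a T_0-step cycle, restarting at the base lr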

2 files changed: 107 additions, 7 deletions

ignite/handlers/param_scheduler.py

Lines changed: 63 additions & 5 deletions

@@ -10,7 +10,7 @@
 from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer
 
 # https://github.com/pytorch/ignite/issues/2773
@@ -792,6 +792,61 @@ def simulate_values(  # type: ignore[override]
         return output
 
 
+class _CosineAnnealingWarmRestarts:
+    def __init__(self, lr_scheduler: CosineAnnealingWarmRestarts):
+        self._lr_scheduler = lr_scheduler
+
+    @property
+    def last_epoch(self) -> int:
+        return self._lr_scheduler.last_epoch
+
+    @last_epoch.setter
+    def last_epoch(self, value: int) -> None:
+        self._lr_scheduler.last_epoch = value
+
+    @property
+    def optimizer(self) -> torch.optim.Optimizer:
+        return self._lr_scheduler.optimizer
+
+    def get_lr(self, epoch: Optional[int] = None) -> List[float]:
+        # TODO: Remove this workaround when pytorch has fixed wrong type hints:
+        # https://github.com/pytorch/pytorch/pull/102067
+        # Replace below T_mult -> self._lr_scheduler.T_mult
+        # Replace below eta_min -> self._lr_scheduler.eta_min
+        T_mult = cast(int, self._lr_scheduler.T_mult)
+        eta_min = cast(float, self._lr_scheduler.eta_min)
+
+        if epoch is None and self.last_epoch < 0:
+            epoch = 0
+        if epoch is None:
+            epoch = self.last_epoch + 1
+            self._lr_scheduler.T_cur = self._lr_scheduler.T_cur + 1
+            if self._lr_scheduler.T_cur >= self._lr_scheduler.T_i:
+                self._lr_scheduler.T_cur = self._lr_scheduler.T_cur - self._lr_scheduler.T_i
+                self._lr_scheduler.T_i = self._lr_scheduler.T_i * T_mult
+        else:
+            if epoch < 0:
+                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
+            if epoch >= self._lr_scheduler.T_0:
+                if T_mult == 1:
+                    self._lr_scheduler.T_cur = epoch % self._lr_scheduler.T_0
+                else:
+                    n = int(math.log((epoch / self._lr_scheduler.T_0 * (T_mult - 1) + 1), T_mult))
+                    self._lr_scheduler.T_cur = epoch - self._lr_scheduler.T_0 * (T_mult**n - 1) / (T_mult - 1)
+                    self._lr_scheduler.T_i = self._lr_scheduler.T_0 * T_mult**n
+            else:
+                self._lr_scheduler.T_i = self._lr_scheduler.T_0
+                self._lr_scheduler.T_cur = epoch
+
+        self.last_epoch = math.floor(epoch)
+
+        return [
+            eta_min
+            + (base_lr - eta_min) * (1 + math.cos(math.pi * self._lr_scheduler.T_cur / self._lr_scheduler.T_i)) / 2
+            for base_lr in self._lr_scheduler.base_lrs
+        ]
+
+
 class LRScheduler(ParamScheduler):
     """A wrapper class to call `torch.optim.lr_scheduler` objects as `ignite` handlers.
 
@@ -853,7 +908,10 @@ def __init__(
                 f"but given {type(lr_scheduler)}"
             )
 
-        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler: Union[PyTorchLRScheduler, _CosineAnnealingWarmRestarts] = lr_scheduler
+        if isinstance(lr_scheduler, CosineAnnealingWarmRestarts):
+            self.lr_scheduler = _CosineAnnealingWarmRestarts(lr_scheduler)
+
         super(LRScheduler, self).__init__(
             optimizer=self.lr_scheduler.optimizer,
             param_name="lr",
@@ -863,7 +921,7 @@ def __init__(
            warnings.warn(
                "Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
                "instead of Events.ITERATION_STARTED to make sure to use "
-               "the first lr value from the optimizer, otherwise it is will be skipped"
+               "the first lr value from the optimizer, otherwise it will be skipped"
            )
            self.lr_scheduler.last_epoch += 1
@@ -876,9 +934,9 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
     def get_param(self) -> Union[float, List[float]]:
         """Method to get current optimizer's parameter value"""
         # Emulate context manager for pytorch>=1.4
-        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[union-attr]
         lr_list = cast(List[float], self.lr_scheduler.get_lr())
-        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[union-attr]
         if len(lr_list) == 1:
             return lr_list[0]
         else:
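For reference, the list comprehension at the end of the proxy's get_lr evaluates the usual cosine annealing with warm restarts value for each parameter group:

\eta_t = \eta_{\min} + (\eta_{\text{base}} - \eta_{\min}) \cdot \frac{1 + \cos(\pi \, T_{\text{cur}} / T_i)}{2}

where \eta_{\text{base}} is the group's base_lr, T_{\text{cur}} counts steps since the last restart, and T_i is the current restart period (T_0 at first, multiplied by T_mult after every restart).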

tests/ignite/handlers/test_param_scheduler.py

Lines changed: 44 additions & 2 deletions

@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 import torch
-from torch.optim.lr_scheduler import ExponentialLR, StepLR
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR, StepLR
 
 from ignite.engine import Engine, Events
 from ignite.handlers.param_scheduler import (
@@ -650,7 +650,7 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
     state_dict1 = scheduler1.state_dict()
 
     torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
-    with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it is will be skipped"):
+    with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
         scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
     state_dict2 = scheduler2.state_dict()
@@ -1362,3 +1362,45 @@ def test_reduce_lr_on_plateau_scheduler_asserts():
     with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
         metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
         ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)
+
+
+@pytest.mark.parametrize("warmup_end_value", [0.23, None])
+@pytest.mark.parametrize("T_0", [1, 12])
+@pytest.mark.parametrize("T_mult", [1, 3])
+def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
+    lr = 0.2
+    steps = 200
+    warm_steps = 50
+    warm_start = 0.023
+
+    def get_optim():
+        t1 = torch.zeros([1], requires_grad=True)
+        return torch.optim.SGD([t1], lr=lr)
+
+    def get_cos_shed():
+        return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, verbose=False)
+
+    optimizer = get_optim()
+    scheduler = get_cos_shed()
+    cosine_lrs = []
+    for i in range(steps):
+        cosine_lrs.append(optimizer.param_groups[0]["lr"])
+        scheduler.step()
+
+    optimizer = get_optim()
+    scheduler = create_lr_scheduler_with_warmup(
+        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
+    )
+
+    warm_lrs = []
+    real_warm_steps = warm_steps if warmup_end_value is not None else (warm_steps - 1)
+    for epoch in range(real_warm_steps + steps):
+        scheduler(None)
+        warm_lrs.append(optimizer.param_groups[0]["lr"])
+
+    if warmup_end_value is not None:
+        np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
+        assert warm_lrs[real_warm_steps:] == cosine_lrs
+    else:
+        np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
+        assert warm_lrs[real_warm_steps:] == cosine_lrs
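The new test drives the combined scheduler by calling it directly with scheduler(None). In a training script the same object would normally be attached to an engine event instead; a minimal sketch under that assumption (the engine, dummy step, and run length are illustrative, and the event choice follows ignite's documented pattern for its native param schedulers, so double-check it against your setup):

# Sketch (assumed wiring, not from the commit): linear warmup into the cosine
# schedule, reusing the test's hyperparameters (0.023 -> 0.23 over 50 steps,
# then CosineAnnealingWarmRestarts with T_0=12, T_mult=3).
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup

param = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.2)


def train_step(engine, batch):
    pass  # placeholder; a real step would compute a loss and call optimizer.step()


trainer = Engine(train_step)

scheduler = create_lr_scheduler_with_warmup(
    CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3),
    warmup_start_value=0.023,
    warmup_end_value=0.23,
    warmup_duration=50,
)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

trainer.run(range(10), max_epochs=25)  # 250 iterations: 50 warmup steps, then cosine restarts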
