
Commit c207e02

MOAR optimizer changes. Woo!
1 parent 42c1f0c commit c207e02

6 files changed: 61 additions, 51 deletions


README.md

Lines changed: 9 additions & 0 deletions

```diff
@@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### Aug 18, 2021
+* Optimizer bonanza!
+* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
+* Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
+* Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
+* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
+* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
+* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
+
 ### July 12, 2021
 * Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)
```
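
The new optimizers are picked up through the usual `timm` factory path. A minimal usage sketch, not part of this commit: the model name and hyperparameters are purely illustrative, and it assumes `create_optimizer_v2` is exported from `timm.optim` as the factory module changes below suggest.

```python
import torch
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18', pretrained=False)

# 'lambc' = LAMB with trust ratio clipping; 'nlarc' = Nesterov LARS with LARC-style clipping
optimizer = create_optimizer_v2(model, opt='lambc', lr=5e-3, weight_decay=0.02)

# one dummy training step
loss = model(torch.randn(2, 3, 224, 224)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```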

tests/test_optim.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -463,26 +463,26 @@ def test_adafactor(optimizer):
     _test_model(optimizer, dict(lr=5e-2))
 
 
-@pytest.mark.parametrize('optimizer', ['lamb'])
+@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
 def test_lamb(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
     )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
+            _build_params_dict(weight, bias, lr=1e-3),
             optimizer,
             lr=1e-3)
     )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
+            _build_params_dict_single(weight, bias, lr=1e-3),
             optimizer,
             lr=1e-3)
     )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
+            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
     )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
```
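
For readers unfamiliar with the test helpers, `_build_params_dict` and `_build_params_dict_single` construct parameter-group inputs; the sketch below is an assumption about their shape rather than code copied from the test file, shown only to make the per-group `lr` overrides above concrete.

```python
import torch
from timm.optim import create_optimizer_v2

weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))

# Assumed shape of _build_params_dict(weight, bias, lr=1e-3): the bias gets its own
# group with an overriding learning rate, the weight falls back to the optimizer lr.
param_groups = [
    {'params': [weight]},
    {'params': [bias], 'lr': 1e-3},
]

# Same call pattern as the parametrized test, pinned here to the new 'lambc' alias.
optimizer = create_optimizer_v2(param_groups, 'lambc', lr=1e-3)
```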

timm/optim/lamb.py

Lines changed: 12 additions & 9 deletions

```diff
@@ -73,10 +73,9 @@ class Lamb(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         grad_averaging (bool, optional): whether apply (1-beta2) to grad when
             calculating running averages of gradient. (default: True)
-        set_grad_none (bool, optional): whether set grad to None when zero_grad()
-            method is called. (default: True)
         max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
-        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
+        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
@@ -87,10 +86,11 @@ class Lamb(Optimizer):
 
     def __init__(
             self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
-            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False):
+            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
         defaults = dict(
             lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
+            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip, always_adapt=always_adapt)
         super().__init__(params, defaults)
 
     @torch.no_grad()
@@ -105,7 +105,7 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]["params"][0].device
+        device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
         global_grad_norm = torch.zeros(1, device=device)
         for group in self.param_groups:
@@ -171,9 +171,9 @@ def step(self, closure=None):
                 if weight_decay != 0:
                     update.add_(p, alpha=weight_decay)
 
-                if weight_decay != 0 or group['use_nvlamb']:
-                    # Layer adaptation. By default, skip layer adaptation on parameters that are
-                    # excluded from weight decay, unless use_nvlamb == True, then always enabled.
+                if weight_decay != 0 or group['always_adapt']:
+                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
+                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                     w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
                     # FIXME nested where required since logical and/or not working in PT XLA
@@ -182,6 +182,9 @@ def step(self, closure=None):
                         torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                         one_tensor,
                     )
+                    if group['trust_clip']:
+                        # LAMBC trust clipping, upper bound fixed at one
+                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                     update.mul_(trust_ratio)
 
                 p.add_(update, alpha=-group['lr'])
```
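
The core of the change above is the renamed `always_adapt` flag and the new `trust_clip` option. Below is a standalone sketch of just the layer-wise trust ratio computation, mirroring the nested `torch.where` formulation from the diff; it is a simplification for illustration, not the optimizer itself.

```python
import torch

def lamb_trust_ratio(param: torch.Tensor, update: torch.Tensor, trust_clip: bool = False) -> torch.Tensor:
    """Compute ||w|| / ||update|| with the same edge-case handling as the diff above."""
    one = torch.tensor(1.0, device=param.device)
    w_norm = param.norm(2.0)
    g_norm = update.norm(2.0)
    # nested where keeps the op graph XLA-friendly (no Python bool logic on tensors)
    ratio = torch.where(
        w_norm > 0,
        torch.where(g_norm > 0, w_norm / g_norm, one),
        one,
    )
    if trust_clip:
        # LAMBC: cap the ratio at 1 so adaptation can only shrink the step
        ratio = torch.minimum(ratio, one)
    return ratio

w = torch.randn(64, 64)
print(lamb_trust_ratio(w, torch.randn(64, 64) * 10))                     # ratio < 1
print(lamb_trust_ratio(w, torch.randn(64, 64) * 0.01, trust_clip=True))  # would exceed 1, clipped to 1
```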

timm/optim/lars.py

Lines changed: 14 additions & 14 deletions

```diff
@@ -11,7 +11,7 @@
 Copyright 2021 Ross Wightman
 """
 import torch
-from torch.optim.optimizer import Optimizer, required
+from torch.optim.optimizer import Optimizer
 
 
 class Lars(Optimizer):
@@ -21,31 +21,31 @@ class Lars(Optimizer):
 
     Args:
         params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. (default: 1e-3)
+        lr (float, optional): learning rate (default: 1.0).
         momentum (float, optional): momentum factor (default: 0)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
         trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
         eps (float): eps for division denominator (default: 1e-8)
-        larc (bool): enable LARC clipping (default: False)
-        always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
+        trust_clip (bool): enable LARC trust ratio clipping (default: False)
+        always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
     """
 
     def __init__(
             self,
             params,
-            lr=required,
+            lr=1.0,
             momentum=0,
             dampening=0,
             weight_decay=0,
             nesterov=False,
             trust_coeff=0.001,
             eps=1e-8,
-            larc=False,
-            always_scale=False,
+            trust_clip=False,
+            always_adapt=False,
     ):
-        if lr is not required and lr < 0.0:
+        if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr}")
         if momentum < 0.0:
             raise ValueError(f"Invalid momentum value: {momentum}")
@@ -62,8 +62,8 @@ def __init__(
             nesterov=nesterov,
             trust_coeff=trust_coeff,
             eps=eps,
-            larc=larc,
-            always_scale=always_scale,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt,
         )
         super().__init__(params, defaults)
 
@@ -84,7 +84,7 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]["params"][0].device
+        device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
 
         # exclude scaling for params with 0 weight decay
@@ -101,9 +101,9 @@ def step(self, closure=None):
                     continue
                grad = p.grad
 
-                # apply LARS scaling, LARC clipping, weight decay
+                # apply LARS LR adaptation, LARC clipping, weight decay
                # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
-                if weight_decay != 0 or group['always_scale']:
+                if weight_decay != 0 or group['always_adapt']:
                    w_norm = p.norm(2.0)
                    g_norm = grad.norm(2.0)
                    trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
@@ -113,7 +113,7 @@ def step(self, closure=None):
                        torch.where(g_norm > 0, trust_ratio, one_tensor),
                        one_tensor,
                    )
-                    if group['larc']:
+                    if group['trust_clip']:
                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
                    grad.add(p, alpha=weight_decay)
                    grad.mul_(trust_ratio)
```
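
For comparison with LAMB, the LARS adaptation above scales by `trust_coeff * ||w|| / (||g|| + wd * ||w|| + eps)`, and the `trust_clip` (LARC) path divides by the group `lr` before capping at 1, so the eventual `lr`-scaled step is bounded by `min(trust_ratio, lr)`. The sketch below restates just that arithmetic outside the optimizer, for illustration only.

```python
import torch

def lars_trust_ratio(
        param: torch.Tensor,
        grad: torch.Tensor,
        lr: float,
        weight_decay: float = 0.0,
        trust_coeff: float = 0.001,
        eps: float = 1e-8,
        trust_clip: bool = False,
) -> torch.Tensor:
    """Multiplier applied to the gradient, per the LARS/LARC logic in the diff above."""
    one = torch.tensor(1.0, device=param.device)
    w_norm = param.norm(2.0)
    g_norm = grad.norm(2.0)
    ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
    ratio = torch.where(
        w_norm > 0,
        torch.where(g_norm > 0, ratio, one),
        one,
    )
    if trust_clip:
        # LARC: after the optimizer later multiplies by lr, the step scale is min(trust_ratio, lr)
        ratio = torch.minimum(ratio / lr, one)
    return ratio

p = torch.randn(256, 256)
print(lars_trust_ratio(p, torch.randn(256, 256), lr=1.0, weight_decay=1e-5, trust_clip=True))
```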

timm/optim/madgrad.py

Lines changed: 18 additions & 22 deletions

```diff
@@ -87,42 +87,39 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         """Performs a single optimization step.
 
         Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+            closure (callable, optional): A closure that reevaluates the model and returns the loss.
         """
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
-        step = self.state.setdefault('step', 0)  # k
-
         for group in self.param_groups:
-            eps = group["eps"]
-            lr = group["lr"] + eps
-            weight_decay = group["weight_decay"]
-            momentum = group["momentum"]
-
+            eps = group['eps']
+            lr = group['lr'] + eps
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
             ck = 1 - momentum
-            lamb = lr * math.sqrt(step + 1)
 
             for p in group["params"]:
                 if p.grad is None:
                     continue
                 grad = p.grad
-                state = self.state[p]
-
-                if "grad_sum_sq" not in state:
-                    state["grad_sum_sq"] = torch.zeros_like(p)
-                    state["s"] = torch.zeros_like(p)
-                    if momentum != 0:
-                        state["x0"] = torch.clone(p).detach()
-
                 if momentum != 0.0 and grad.is_sparse:
                     raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
 
-                grad_sum_sq = state["grad_sum_sq"]
-                s = state["s"]
+                state = self.state[p]
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['grad_sum_sq'] = torch.zeros_like(p)
+                    state['s'] = torch.zeros_like(p)
+                    if momentum != 0:
+                        state['x0'] = torch.clone(p).detach()
+
+                state['step'] += 1
+                grad_sum_sq = state['grad_sum_sq']
+                s = state['s']
+                lamb = lr * math.sqrt(state['step'])
 
                 # Apply weight decay
                 if weight_decay != 0:
@@ -166,7 +163,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                     rms = grad_sum_sq.pow(1 / 3).add_(eps)
                     x0 = p.addcdiv(s, rms, value=1)
                 else:
-                    x0 = state["x0"]
+                    x0 = state['x0']
 
                 # Accumulate second moments
                 grad_sum_sq.addcmul_(grad, grad, value=lamb)
@@ -184,5 +181,4 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                 # p is a moving average of z
                 p.mul_(1 - ck).add_(z, alpha=ck)
 
-        self.state['step'] += 1
         return loss
```
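
The main behavioural change above is moving the step counter from a single global `self.state['step']` into per-parameter state, initialised lazily the first time a parameter is seen. That matches the usual PyTorch optimizer state pattern and avoids the global counter that tripped up PyTorch XLA. Below is a toy optimizer sketching just that state-handling pattern; it is not MADGRAD's actual update rule, and all names are illustrative.

```python
import math
import torch

class TinyAdaGradLike(torch.optim.Optimizer):
    """Illustrative optimizer showing lazy per-parameter state with a per-param step count."""

    def __init__(self, params, lr=1e-2, eps=1e-6):
        super().__init__(params, dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr, eps = group['lr'], group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:  # lazy init, same pattern as the diff
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                state['step'] += 1
                # the step-dependent scale lives next to the state it belongs to
                lamb = lr * math.sqrt(state['step'])
                state['grad_sum_sq'].addcmul_(p.grad, p.grad, value=lamb)
                denom = state['grad_sum_sq'].pow(1 / 3).add_(eps)
                p.addcdiv_(p.grad, denom, value=-lr)
        return loss

# usage: parameters update without any global step counter
w = torch.nn.Parameter(torch.randn(4, 4))
opt = TinyAdaGradLike([w], lr=1e-2)
(w ** 2).sum().backward()
opt.step()
```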

timm/optim/optim_factory.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -164,12 +164,14 @@ def create_optimizer_v2(
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'lamb':
         optimizer = Lamb(parameters, **opt_args)
+    elif opt_lower == 'lambc':
+        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
     elif opt_lower == 'larc':
-        optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
+        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
     elif opt_lower == 'lars':
         optimizer = Lars(parameters, momentum=momentum, **opt_args)
     elif opt_lower == 'nlarc':
-        optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
+        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
     elif opt_lower == 'nlars':
         optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'madgrad':
```
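
With the rename, the LARC/NLARC strings are now just `Lars` with `trust_clip=True`, and the new `lambc` string maps to `Lamb` with `trust_clip=True`. A quick sketch of the two routes follows; the direct class imports use the module paths shown in the diff, and factory defaults such as `weight_decay` may differ from the class defaults, so treat the pairs as roughly rather than exactly equivalent.

```python
import torch
from timm.optim import create_optimizer_v2
from timm.optim.lamb import Lamb
from timm.optim.lars import Lars

params = [torch.nn.Parameter(torch.randn(8, 8))]

# via the factory opt strings added/updated in this commit
opt_a = create_optimizer_v2(params, 'lambc', lr=1e-3)
opt_b = create_optimizer_v2(params, 'larc', lr=1.0, momentum=0.9)

# roughly the same thing constructed directly
opt_c = Lamb(params, lr=1e-3, weight_decay=0., trust_clip=True)
opt_d = Lars(params, lr=1.0, momentum=0.9, trust_clip=True)
```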
