@@ -26,11 +26,7 @@ Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji
 Key distinctives of this implementation include:

 - Homogeneous multi-node multi-GPU support in PyTorch.
-- Learning rate grafting [3]. Our version of grafting only grafts the second moment/diagonal preconditioner. Momentum/first moment updates are performed separately from grafting. Supports the methods:
-  - SGD
-  - Adagrad
-  - RMSprop
-  - Adam
+- Learning rate grafting [3]. Our version of grafting only grafts the second moment/diagonal preconditioner. Momentum/first moment updates are performed separately from grafting.
 - Supports both normal and AdamW (decoupled) weight decay.
 - Incorporates exponential moving averaging (with or without bias correction) to estimate the first moment (akin to Adam).
 - Incorporates momentum and Nesterov acceleration.
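For orientation, the sketch below shows how these features map onto the optimizer's constructor once the grafting configs are renamed. It is an illustrative sketch, not a snippet from the README: it only uses arguments that appear in the hunks that follow (`lr`, `momentum`, `weight_decay`, `use_decoupled_weight_decay`, `max_preconditioner_dim`, `precondition_frequency`, `grafting_config`), and `instantiate_model()` is the same placeholder the README uses.

```python
from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig

model = instantiate_model()  # placeholder for your own model construction

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,                          # step size; grafting carries over an SGD-like update scale
    momentum=0.9,                     # momentum/first-moment update, applied separately from grafting
    weight_decay=1e-05,
    use_decoupled_weight_decay=True,  # AdamW-style weight decay; set False for "normal" weight decay
    max_preconditioner_dim=8192,      # block large tensors so the preconditioners stay tractable
    precondition_frequency=100,       # how often the root-inverse preconditioners are recomputed
    grafting_config=SGDPreconditionerConfig(),  # graft the SGD second-moment/diagonal scale
)
```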
@@ -91,7 +87,7 @@ optimizer = SGD(
 we would instead use:
 ```python
 import torch
-from distributed_shampoo import DistributedShampoo, SGDGraftingConfig
+from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig

 model = instantiate_model()

@@ -104,7 +100,7 @@ optimizer = DistributedShampoo(
     weight_decay=1e-05,
     max_preconditioner_dim=8192,
     precondition_frequency=100,
-    grafting_config=SGDGraftingConfig(),
+    grafting_config=SGDPreconditionerConfig(),
 )
 ```

@@ -129,7 +125,7 @@ optimizer = Adam(
 we would instead use:
 ```python
 import torch
-from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
+from distributed_shampoo import AdamPreconditionerConfig, DistributedShampoo

 model = instantiate_model()

@@ -142,7 +138,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=False,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-08,
     ),
@@ -168,7 +164,7 @@ optimizer = Adagrad(
 we would instead use:
 ```python
 import torch
-from distributed_shampoo import AdaGradGraftingConfig, DistributedShampoo
+from distributed_shampoo import AdaGradPreconditionerConfig, DistributedShampoo

 model = instantiate_model()

@@ -181,7 +177,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=False,
-    grafting_config=AdaGradGraftingConfig(
+    grafting_config=AdaGradPreconditionerConfig(
         epsilon=1e-10,
     ),
 )
@@ -207,7 +203,7 @@ optimizer = AdamW(
 we would instead use:
 ```python
 import torch
-from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
+from distributed_shampoo import AdamPreconditionerConfig, DistributedShampoo

 model = instantiate_model()

@@ -220,7 +216,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-08,
     ),
@@ -308,8 +304,8 @@ optimizer = DistributedShampoo(
         {
             "params": other_params,
             "lr": 3e-4,
-            "start_preconditioning_step", math.inf,
-            "grafting_config": AdamGraftingConfig(
+            "start_preconditioning_step": math.inf,
+            "grafting_config": AdamPreconditionerConfig(
                 beta2=0.95,
                 epsilon=1e-10,
             ),
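For readers skimming the hunk above: besides the class rename, the `+` lines replace a stray comma after `"start_preconditioning_step"` with the intended colon, and setting that key to `math.inf` keeps Shampoo preconditioning permanently off for that parameter group, so the group is effectively updated by the grafted Adam method alone. Below is a hedged sketch of how such a parameter-group list might be assembled; the `split_parameters` helper and the first group's contents are illustrative assumptions, not taken from the README.

```python
import math

from distributed_shampoo import AdamPreconditionerConfig, DistributedShampoo

model = instantiate_model()                             # placeholder used throughout the README
shampoo_params, other_params = split_parameters(model)  # hypothetical helper splitting the parameters

optimizer = DistributedShampoo(
    [
        {"params": shampoo_params},  # falls back to the top-level settings below
        {
            "params": other_params,
            "lr": 3e-4,
            # math.inf => Shampoo preconditioning never starts for this group,
            # so it is effectively trained with the grafted Adam update only.
            "start_preconditioning_step": math.inf,
            "grafting_config": AdamPreconditionerConfig(
                beta2=0.95,
                epsilon=1e-10,
            ),
        },
    ],
    lr=0.01,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    grafting_config=AdamPreconditionerConfig(
        beta2=0.999,
        epsilon=1e-08,
    ),
)
```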
@@ -343,7 +339,7 @@ import torch
 import torch.distributed as dist

 from distributed_shampoo import (
-    AdamGraftingConfig,
+    AdamPreconditionerConfig,
     DDPDistributedConfig,
     DistributedShampoo,
 )
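As a rough end-to-end illustration of the DDP case after the rename, the sketch below wires the imports above into a standard PyTorch DDP setup. The process-group and DDP boilerplate is generic PyTorch, and passing the config through a `distributed_config` keyword with a default-constructed `DDPDistributedConfig()` is an assumption on my part rather than something shown in these hunks; the README's own DDP example below is authoritative.

```python
import os

import torch
import torch.distributed as dist

from distributed_shampoo import (
    AdamPreconditionerConfig,
    DDPDistributedConfig,
    DistributedShampoo,
)

# Standard torchrun-style distributed setup (assumed environment variables).
dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

model = instantiate_model().to(device)  # placeholder for your own model construction
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index])

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamPreconditionerConfig(
        beta2=0.999,
        epsilon=1e-12,
    ),
    # Assumption: the DDP behavior is selected via this keyword with default settings;
    # see the README's DDP training support section for the exact configuration.
    distributed_config=DDPDistributedConfig(),
)
```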
@@ -376,7 +372,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-12,
     ),
@@ -404,7 +400,7 @@ import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

 from distributed_shampoo import (
-    AdamGraftingConfig,
+    AdamPreconditionerConfig,
     compile_fsdp_parameter_metadata,
     DistributedShampoo,
     FSDPDistributedConfig,
@@ -434,7 +430,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-12,
     ),
@@ -456,7 +452,7 @@ import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

 from distributed_shampoo import (
-    AdamGraftingConfig,
+    AdamPreconditionerConfig,
     compile_fsdp_parameter_metadata,
     DistributedShampoo,
     HSDPDistributedConfig,
@@ -493,7 +489,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-12,
     ),
@@ -519,7 +515,7 @@ import torch.distributed as dist
 from torch.distributed.fsdp import fully_shard

 from distributed_shampoo import (
-    AdamGraftingConfig,
+    AdamPreconditionerConfig,
     DistributedShampoo,
     FullyShardDistributedConfig,
 )
@@ -548,7 +544,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-12,
     ),
@@ -570,7 +566,7 @@ import torch.distributed as dist
 from torch.distributed.fsdp import fully_shard

 from distributed_shampoo import (
-    AdamGraftingConfig,
+    AdamPreconditionerConfig,
     DistributedShampoo,
     HybridShardDistributedConfig,
 )
@@ -606,7 +602,7 @@ optimizer = DistributedShampoo(
     max_preconditioner_dim=8192,
     precondition_frequency=100,
     use_decoupled_weight_decay=True,
-    grafting_config=AdamGraftingConfig(
+    grafting_config=AdamPreconditionerConfig(
         beta2=0.999,
         epsilon=1e-12,
     ),
@@ -678,7 +674,7 @@ With the inclusion of learning rate grafting, we can extract a good learning rate
     momentum=0.9,
     weight_decay=0.01,
     max_preconditioner_dim=4096,
-    grafting_config=SGDGraftingConfig(),
+    grafting_config=SGDPreconditionerConfig(),
 )
 ```

@@ -699,7 +695,7 @@ With the inclusion of learning rate grafting, we can extract a good learning rate
     momentum=0.9,
     weight_decay=0.01,
     precondition_frequency=100,
-    grafting_config=SGDGraftingConfig(),
+    grafting_config=SGDPreconditionerConfig(),
 )
 ```

@@ -718,7 +714,7 @@ With the inclusion of learning rate grafting, we can extract a good learning rate
     momentum=0.9,
     weight_decay=0.01,
     start_preconditioning_step=300,
-    grafting_config=SGDGraftingConfig(),
+    grafting_config=SGDPreconditionerConfig(),
 )
 ```

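Read together, the three hunks above spell out the hyperparameter-transfer recipe with the renamed config: keep the learning rate, momentum, and weight decay from an already-tuned SGD run, graft SGD, and only adjust the Shampoo-specific knobs. A minimal sketch under that reading (the tuned values are illustrative, and `instantiate_model()` is the README's placeholder):

```python
from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig

model = instantiate_model()  # placeholder for your own model construction

# Hyperparameters previously tuned for torch.optim.SGD (illustrative values).
tuned_lr, tuned_momentum, tuned_weight_decay = 0.01, 0.9, 0.01

optimizer = DistributedShampoo(
    model.parameters(),
    lr=tuned_lr,                      # reused from the SGD run; grafting preserves its scale
    momentum=tuned_momentum,          # reused from the SGD run
    weight_decay=tuned_weight_decay,  # reused from the SGD run
    max_preconditioner_dim=4096,      # Shampoo-specific: block size of the preconditioners
    precondition_frequency=100,       # Shampoo-specific: how often root inverses are refreshed
    start_preconditioning_step=300,   # Shampoo-specific: run the grafted method alone at first
    grafting_config=SGDPreconditionerConfig(),
)
```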