
Commit ef22dd3

tsunghsienlee authored and facebook-github-bot committed
Merge GraftingConfig into PreconditionerConfig (#242)
Summary: Grafting also instantiates a class whose base class is `PreconditionerList`, so it makes sense to merge `GraftingConfig` into `PreconditionerConfig`. Moreover, this merge enriches the configuration choices available for both grafting and preconditioning.

Differential Revision: D81237972
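
For callers, the visible change is that the former `*GraftingConfig` classes are superseded by the corresponding `*PreconditionerConfig` classes, which are now what `grafting_config` accepts. A minimal migration sketch, assuming a `model` is already defined (the hyperparameter values are illustrative, not taken from this commit):

    from distributed_shampoo import AdaGradPreconditionerConfig, DistributedShampoo

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.01,
        betas=(0.9, 0.999),
        epsilon=1e-12,
        precondition_frequency=1,
        use_decoupled_weight_decay=True,
        # Previously: grafting_config=AdaGradGraftingConfig(epsilon=1e-8)
        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
    )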
1 parent 9da988e commit ef22dd3

14 files changed: +233 −259 lines

distributed_shampoo/__init__.py

Lines changed: 8 additions & 11 deletions
@@ -32,8 +32,8 @@
     SVDOrthogonalizationConfig,
 )
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
-    AdamGraftingConfig,
+    AdaGradPreconditionerConfig,
+    AdamPreconditionerConfig,
     AmortizedPreconditionerConfig,
     DDPDistributedConfig,
     DefaultEigenvalueCorrectedShampooConfig,
@@ -48,13 +48,12 @@
     FSDPDistributedConfig,
     FSDPParamAssignmentStrategy,
     FullyShardDistributedConfig,
-    GraftingConfig,
     HSDPDistributedConfig,
     HybridShardDistributedConfig,
     PreconditionerConfig,
-    RMSpropGraftingConfig,
+    RMSpropPreconditionerConfig,
     RootInvShampooPreconditionerConfig,
-    SGDGraftingConfig,
+    SGDPreconditionerConfig,
     ShampooPreconditionerConfig,
     ShampooPT2CompileConfig,
     SignDescentPreconditionerConfig,
@@ -65,12 +64,6 @@

 __all__ = [
     "DistributedShampoo",
-    # `grafting_config` options.
-    "GraftingConfig",  # Abstract base class.
-    "SGDGraftingConfig",
-    "AdaGradGraftingConfig",
-    "RMSpropGraftingConfig",
-    "AdamGraftingConfig",
     # PT2 compile.
     "ShampooPT2CompileConfig",
     # `distributed_config` options.
@@ -98,6 +91,10 @@
     "SignDescentPreconditionerConfig",  # Based on `PreconditionerConfig`.
     "DefaultSpectralDescentPreconditionerConfig",  # Default `SpectralDescentPreconditionerConfig` using `NewtonSchulzOrthogonalizationConfig`.
     "DefaultSignDescentPreconditionerConfig",  # Default `SignDescentPreconditionerConfig`.
+    "SGDPreconditionerConfig",
+    "AdaGradPreconditionerConfig",
+    "RMSpropPreconditionerConfig",
+    "AdamPreconditionerConfig",
     # matrix functions configs.
     "RankDeficientStabilityConfig",  # Abstract base class.
     "PerturbationConfig",  # Based on `RankDeficientStabilityConfig`.

distributed_shampoo/distributed_shampoo.py

Lines changed: 90 additions & 82 deletions
@@ -12,7 +12,6 @@
 from collections.abc import Callable, Iterator
 from copy import deepcopy
 from dataclasses import asdict
-from functools import partial
 from typing import Any, overload

 import torch
@@ -56,8 +55,8 @@
 )

 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
-    AdamGraftingConfig,
+    AdaGradPreconditionerConfig,
+    AdamPreconditionerConfig,
     AmortizedPreconditionerConfig,
     BETA3,
     BETAS,
@@ -78,7 +77,6 @@
     FullyShardDistributedConfig,
     GRAFTING_CONFIG,
     GRAFTING_PRECONDITIONER_LIST,
-    GraftingConfig,
     HSDPDistributedConfig,
     HybridShardDistributedConfig,
     LR,
@@ -94,9 +92,9 @@
     PRECONDITIONER_CONFIG,
     PreconditionerConfig,
     PREVIOUS_GRAD_SELECTOR,
-    RMSpropGraftingConfig,
+    RMSpropPreconditionerConfig,
     RootInvShampooPreconditionerConfig,
-    SGDGraftingConfig,
+    SGDPreconditionerConfig,
     SHAMPOO_PRECONDITIONER_LIST,
     ShampooPT2CompileConfig,
     SignDescentPreconditionerConfig,
@@ -338,7 +336,7 @@ def __init__(
         use_nesterov: bool = False,
         use_bias_correction: bool = True,
         use_decoupled_weight_decay: bool = True,
-        grafting_config: GraftingConfig | None = None,
+        grafting_config: PreconditionerConfig | None = None,
         use_pin_memory: bool = False,
         shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None,
         distributed_config: DistributedConfig = DefaultSingleDeviceDistributedConfig,
@@ -577,93 +575,103 @@ def _initialize_blocked_parameters_state(self) -> None:
             ), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
             param_state[block_index] = {}

+    @torch.no_grad()
+    def _preconditioner_config_to_list_cls(
+        self,
+        state_lists: dict[str, Any],
+        group: dict[str, Any],
+        preconditioner_config: PreconditionerConfig,
+    ) -> PreconditionerList | None:
+        match preconditioner_config:
+            case None:
+                return None
+            case SGDPreconditionerConfig():
+                return SGDPreconditionerList(
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
+                )
+            case (
+                AdaGradPreconditionerConfig()
+                | RMSpropPreconditionerConfig()
+                | AdamPreconditionerConfig()
+            ):
+                return AdagradPreconditionerList(
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
+                    state=self.state,
+                    block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
+                    beta2=(
+                        1.0
+                        if type(group[GRAFTING_CONFIG]) is AdaGradPreconditionerConfig
+                        else group[GRAFTING_CONFIG].beta2
+                    ),
+                    epsilon=group[GRAFTING_CONFIG].epsilon,
+                    use_bias_correction=type(group[GRAFTING_CONFIG])
+                    is AdamPreconditionerConfig,
+                )
+            case (
+                RootInvShampooPreconditionerConfig()
+                | EigendecomposedShampooPreconditionerConfig()
+                | EigenvalueCorrectedShampooPreconditionerConfig()
+            ):
+                preconditioner_config_to_list_cls: dict[
+                    type[PreconditionerConfig], Callable[..., PreconditionerList]
+                ] = {
+                    RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
+                    EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
+                    EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
+                }
+                return preconditioner_config_to_list_cls[type(preconditioner_config)](
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
+                    preconditioner_config=group[PRECONDITIONER_CONFIG],
+                    state=self.state,
+                    block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
+                    beta2=group[BETAS][1],
+                    epsilon=group[EPSILON],
+                    use_bias_correction=group[USE_BIAS_CORRECTION],
+                )
+            case SignDescentPreconditionerConfig():
+                return SignDescentPreconditionerList(
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
+                    preconditioner_config=group[PRECONDITIONER_CONFIG],
+                )
+            case SpectralDescentPreconditionerConfig():
+                assert (
+                    group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
+                ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
+                return SpectralDescentPreconditionerList(
+                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
+                    preconditioner_config=group[PRECONDITIONER_CONFIG],
+                )
+            case _:
+                raise NotImplementedError(f"{preconditioner_config=} not supported!")
+
     @torch.no_grad()
     def _instantiate_shampoo_preconditioner_list(self) -> None:
         for state_lists, group in zip(
             self._per_group_state_lists, self.param_groups, strict=True
         ):
-            match group[PRECONDITIONER_CONFIG]:
-                case (
-                    RootInvShampooPreconditionerConfig()
-                    | EigendecomposedShampooPreconditionerConfig()
-                    | EigenvalueCorrectedShampooPreconditionerConfig()
-                ):
-                    preconditioner_config_to_list_cls: dict[
-                        type[PreconditionerConfig], Callable[..., PreconditionerList]
-                    ] = {
-                        RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
-                        EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
-                        EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
-                    }
-                    preconditioner_list_cls: Callable[..., PreconditionerList] = (
-                        partial(
-                            preconditioner_config_to_list_cls[
-                                type(group[PRECONDITIONER_CONFIG])
-                            ],
-                            state=self.state,
-                            block_info_list=state_lists[
-                                DISTRIBUTOR
-                            ].local_block_info_list,
-                            beta2=group[BETAS][1],
-                            epsilon=group[EPSILON],
-                            use_bias_correction=group[USE_BIAS_CORRECTION],
-                        )
-                    )
-                case SignDescentPreconditionerConfig():
-                    preconditioner_list_cls = partial(SignDescentPreconditionerList)
-                case SpectralDescentPreconditionerConfig():
-                    assert (
-                        group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
-                    ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
-                    preconditioner_list_cls = SpectralDescentPreconditionerList
-                case _:
-                    raise NotImplementedError(
-                        f"{group[PRECONDITIONER_CONFIG]=} not supported!"
-                    )
-
-            state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
-                block_list=state_lists[DISTRIBUTOR].local_blocked_params,
-                preconditioner_config=group[PRECONDITIONER_CONFIG],
+            assert (
+                group[PRECONDITIONER_CONFIG] is not None
+            ), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
+            state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
+                self._preconditioner_config_to_list_cls(
+                    state_lists=state_lists,
+                    group=group,
+                    preconditioner_config=group[PRECONDITIONER_CONFIG],
+                )
             )

     @torch.no_grad()
     def _instantiate_grafting(self) -> None:
         for state_lists, group in zip(
             self._per_group_state_lists, self.param_groups, strict=True
         ):
-            match group[GRAFTING_CONFIG]:
-                case None:
-                    state_lists[GRAFTING_PRECONDITIONER_LIST] = None
-                case SGDGraftingConfig():
-                    state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
-                        block_list=state_lists[DISTRIBUTOR].local_blocked_params,
-                    )
-                case (
-                    AdaGradGraftingConfig()
-                    | RMSpropGraftingConfig()
-                    | AdamGraftingConfig()
-                ):
-                    state_lists[GRAFTING_PRECONDITIONER_LIST] = (
-                        AdagradPreconditionerList(
-                            block_list=state_lists[DISTRIBUTOR].local_blocked_params,
-                            state=self.state,
-                            block_info_list=state_lists[
-                                DISTRIBUTOR
-                            ].local_block_info_list,
-                            beta2=(
-                                1.0
-                                if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig
-                                else group[GRAFTING_CONFIG].beta2
-                            ),
-                            epsilon=group[GRAFTING_CONFIG].epsilon,
-                            use_bias_correction=type(group[GRAFTING_CONFIG])
-                            is AdamGraftingConfig,
-                        )
-                    )
-                case _:
-                    raise NotImplementedError(
-                        f"{group[GRAFTING_CONFIG]=} not supported!"
-                    )
+            state_lists[GRAFTING_PRECONDITIONER_LIST] = (
+                self._preconditioner_config_to_list_cls(
+                    state_lists=state_lists,
+                    group=group,
+                    preconditioner_config=group[GRAFTING_CONFIG],
+                )
+            )

     @torch.no_grad()
     def _instantiate_steps(self) -> None:
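
The net effect in this file is that the two former match blocks in `_instantiate_shampoo_preconditioner_list` and `_instantiate_grafting` collapse into the single `_preconditioner_config_to_list_cls` dispatcher, so the Shampoo preconditioner and the grafting preconditioner are built through one config-to-list mapping. A stripped-down sketch of that dispatch pattern, using illustrative stand-in classes and strings rather than the repository's real ones:

    # Minimal sketch of the shared dispatch pattern; SGDConfig/AdaGradConfig and the
    # returned strings are illustrative stand-ins, not classes from this repository.
    from dataclasses import dataclass

    @dataclass
    class SGDConfig:
        pass

    @dataclass
    class AdaGradConfig:
        epsilon: float = 1e-10

    def build_preconditioner_list(config):
        # Both the main preconditioner and grafting route through the same match,
        # mirroring _preconditioner_config_to_list_cls in the diff above.
        match config:
            case None:
                return None
            case SGDConfig():
                return "SGDPreconditionerList"
            case AdaGradConfig():
                return f"AdagradPreconditionerList(epsilon={config.epsilon})"
            case _:
                raise NotImplementedError(f"{config=} not supported!")

    shampoo_list = build_preconditioner_list(AdaGradConfig(epsilon=1e-8))
    grafting_list = build_preconditioner_list(None)  # grafting disabled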

distributed_shampoo/distributor/gpu_tests/shampoo_ddp_distributor_test.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from distributed_shampoo.distributor.shampoo_block_info import DTensorBlockInfo
 from distributed_shampoo.distributor.shampoo_ddp_distributor import DDPDistributor
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
+    AdaGradPreconditionerConfig,
     DDPDistributedConfig,
     DefaultEigenvalueCorrectedShampooConfig,
     DefaultShampooConfig,
@@ -103,7 +103,7 @@ def _shampoo_optim_factory(
         precondition_frequency=1,
         start_preconditioning_step=2,
         use_decoupled_weight_decay=True,
-        grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
+        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
         distributed_config=distributed_config,
         preconditioner_config=preconditioner_config,
     )

distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 )
 from distributed_shampoo.preconditioner.shampoo_preconditioner_list import SHAMPOO
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
+    AdaGradPreconditionerConfig,
     DefaultSingleDeviceDistributedConfig,
     FSDPDistributedConfig,
     HSDPDistributedConfig,
@@ -124,7 +124,7 @@ def _shampoo_optim_factory(
         precondition_frequency=1,
         start_preconditioning_step=2,
         use_decoupled_weight_decay=True,
-        grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
+        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
         distributed_config=distributed_config,
     )

distributed_shampoo/distributor/gpu_tests/shampoo_fully_shard_distributor_test.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
     FullyShardDistributor,
 )
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
+    AdaGradPreconditionerConfig,
     DefaultSingleDeviceDistributedConfig,
     FullyShardDistributedConfig,
     HybridShardDistributedConfig,
@@ -128,7 +128,7 @@ def _shampoo_optim_factory(
         precondition_frequency=1,
         start_preconditioning_step=2,
         use_decoupled_weight_decay=True,
-        grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
+        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
         distributed_config=distributed_config,
     )

distributed_shampoo/distributor/gpu_tests/shampoo_fully_shard_lossless_distributor_test.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
     FullyShardLosslessDistributor,
 )
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
+    AdaGradPreconditionerConfig,
     DefaultSingleDeviceDistributedConfig,
     FSDPParamAssignmentStrategy,
     FullyShardDistributedConfig,
@@ -120,7 +120,7 @@ def _shampoo_optim_factory(
         precondition_frequency=1,
         start_preconditioning_step=2,
         use_decoupled_weight_decay=True,
-        grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
+        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
        distributed_config=distributed_config,
     )

distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 from distributed_shampoo.distributor.shampoo_hsdp_distributor import HSDPDistributor
 from distributed_shampoo.preconditioner.shampoo_preconditioner_list import SHAMPOO
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
+    AdaGradPreconditionerConfig,
     DefaultSingleDeviceDistributedConfig,
     HSDPDistributedConfig,
     SingleDeviceDistributedConfig,
@@ -124,7 +124,7 @@ def _shampoo_optim_factory(
         precondition_frequency=1,
         start_preconditioning_step=2,
         use_decoupled_weight_decay=True,
-        grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
+        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
         distributed_config=distributed_config,
     )

distributed_shampoo/distributor/gpu_tests/shampoo_hybrid_shard_distributor_test.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 )
 from distributed_shampoo.preconditioner.shampoo_preconditioner_list import SHAMPOO
 from distributed_shampoo.shampoo_types import (
-    AdaGradGraftingConfig,
+    AdaGradPreconditionerConfig,
     DDPDistributedConfig,
     DefaultSingleDeviceDistributedConfig,
     FullyShardDistributedConfig,
@@ -146,7 +146,7 @@ def _shampoo_optim_factory(
         precondition_frequency=1,
         start_preconditioning_step=start_preconditioning_step,
         use_decoupled_weight_decay=True,
-        grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
+        grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
         distributed_config=distributed_config,
     )
