19 changes: 8 additions & 11 deletions distributed_shampoo/__init__.py
@@ -32,8 +32,8 @@
SVDOrthogonalizationConfig,
)
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
AdaGradPreconditionerConfig,
AdamPreconditionerConfig,
AmortizedPreconditionerConfig,
DDPDistributedConfig,
DefaultEigenvalueCorrectedShampooConfig,
@@ -48,13 +48,12 @@
FSDPDistributedConfig,
FSDPParamAssignmentStrategy,
FullyShardDistributedConfig,
GraftingConfig,
HSDPDistributedConfig,
HybridShardDistributedConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
RMSpropPreconditionerConfig,
RootInvShampooPreconditionerConfig,
SGDGraftingConfig,
SGDPreconditionerConfig,
ShampooPreconditionerConfig,
ShampooPT2CompileConfig,
SignDescentPreconditionerConfig,
@@ -65,12 +64,6 @@

__all__ = [
"DistributedShampoo",
# `grafting_config` options.
"GraftingConfig", # Abstract base class.
"SGDGraftingConfig",
"AdaGradGraftingConfig",
"RMSpropGraftingConfig",
"AdamGraftingConfig",
# PT2 compile.
"ShampooPT2CompileConfig",
# `distributed_config` options.
@@ -98,6 +91,10 @@
"SignDescentPreconditionerConfig", # Based on `PreconditionerConfig`.
"DefaultSpectralDescentPreconditionerConfig", # Default `SpectralDescentPreconditionerConfig` using `NewtonSchulzOrthogonalizationConfig`.
"DefaultSignDescentPreconditionerConfig", # Default `SignDescentPreconditionerConfig`.
"SGDPreconditionerConfig",
"AdaGradPreconditionerConfig",
"RMSpropPreconditionerConfig",
"AdamPreconditionerConfig",
# matrix functions configs.
"RankDeficientStabilityConfig", # Abstract base class.
"PerturbationConfig", # Based on `RankDeficientStabilityConfig`.
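The `__init__.py` change above drops the `*GraftingConfig` exports and exposes the grafting options as `*PreconditionerConfig` classes instead. Below is a minimal, hypothetical usage sketch of the renamed API: the constructor keywords mirror the test factories later in this diff, while the toy model and the `lr` value are illustrative assumptions rather than anything taken from this PR.

```python
import torch

from distributed_shampoo import (
    AdaGradPreconditionerConfig,  # formerly AdaGradGraftingConfig
    DistributedShampoo,
)

# Illustrative model and hyperparameters; not part of this PR.
model = torch.nn.Linear(16, 4)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,
    precondition_frequency=1,
    start_preconditioning_step=2,
    use_decoupled_weight_decay=True,
    # Grafting is now configured with a *PreconditionerConfig subclass
    # instead of the removed *GraftingConfig classes.
    grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
)
```

The rename reflects that grafting methods are now treated as ordinary preconditioners drawn from the same `PreconditionerConfig` hierarchy as the main `preconditioner_config` option.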
172 changes: 90 additions & 82 deletions distributed_shampoo/distributed_shampoo.py
@@ -12,7 +12,6 @@
from collections.abc import Callable, Iterator
from copy import deepcopy
from dataclasses import asdict
from functools import partial
from typing import Any, overload

import torch
@@ -56,8 +55,8 @@
)

from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
AdaGradPreconditionerConfig,
AdamPreconditionerConfig,
AmortizedPreconditionerConfig,
BETA3,
BETAS,
@@ -78,7 +77,6 @@
FullyShardDistributedConfig,
GRAFTING_CONFIG,
GRAFTING_PRECONDITIONER_LIST,
GraftingConfig,
HSDPDistributedConfig,
HybridShardDistributedConfig,
LR,
@@ -94,9 +92,9 @@
PRECONDITIONER_CONFIG,
PreconditionerConfig,
PREVIOUS_GRAD_SELECTOR,
RMSpropGraftingConfig,
RMSpropPreconditionerConfig,
RootInvShampooPreconditionerConfig,
SGDGraftingConfig,
SGDPreconditionerConfig,
SHAMPOO_PRECONDITIONER_LIST,
ShampooPT2CompileConfig,
SignDescentPreconditionerConfig,
@@ -338,7 +336,7 @@ def __init__(
use_nesterov: bool = False,
use_bias_correction: bool = True,
use_decoupled_weight_decay: bool = True,
grafting_config: GraftingConfig | None = None,
grafting_config: PreconditionerConfig | None = None,
use_pin_memory: bool = False,
shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None,
distributed_config: DistributedConfig = DefaultSingleDeviceDistributedConfig,
@@ -577,93 +575,103 @@ def _initialize_blocked_parameters_state(self) -> None:
), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
param_state[block_index] = {}

@torch.no_grad()
def _preconditioner_config_to_list_cls(
self,
state_lists: dict[str, Any],
group: dict[str, Any],
preconditioner_config: PreconditionerConfig,
) -> PreconditionerList | None:
match preconditioner_config:
case None:
return None
case SGDPreconditionerConfig():
return SGDPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
)
case (
AdaGradPreconditionerConfig()
| RMSpropPreconditionerConfig()
| AdamPreconditionerConfig()
):
return AdagradPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
beta2=(
1.0
if type(group[GRAFTING_CONFIG]) is AdaGradPreconditionerConfig
else group[GRAFTING_CONFIG].beta2
),
epsilon=group[GRAFTING_CONFIG].epsilon,
use_bias_correction=type(group[GRAFTING_CONFIG])
is AdamPreconditionerConfig,
)
case (
RootInvShampooPreconditionerConfig()
| EigendecomposedShampooPreconditionerConfig()
| EigenvalueCorrectedShampooPreconditionerConfig()
):
preconditioner_config_to_list_cls: dict[
type[PreconditionerConfig], Callable[..., PreconditionerList]
] = {
RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
}
return preconditioner_config_to_list_cls[type(preconditioner_config)](
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
preconditioner_config=group[PRECONDITIONER_CONFIG],
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
beta2=group[BETAS][1],
epsilon=group[EPSILON],
use_bias_correction=group[USE_BIAS_CORRECTION],
)
case SignDescentPreconditionerConfig():
return SignDescentPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
preconditioner_config=group[PRECONDITIONER_CONFIG],
)
case SpectralDescentPreconditionerConfig():
assert (
group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
return SpectralDescentPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
preconditioner_config=group[PRECONDITIONER_CONFIG],
)
case _:
raise NotImplementedError(f"{preconditioner_config=} not supported!")

@torch.no_grad()
def _instantiate_shampoo_preconditioner_list(self) -> None:
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
match group[PRECONDITIONER_CONFIG]:
case (
RootInvShampooPreconditionerConfig()
| EigendecomposedShampooPreconditionerConfig()
| EigenvalueCorrectedShampooPreconditionerConfig()
):
preconditioner_config_to_list_cls: dict[
type[PreconditionerConfig], Callable[..., PreconditionerList]
] = {
RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
}
preconditioner_list_cls: Callable[..., PreconditionerList] = (
partial(
preconditioner_config_to_list_cls[
type(group[PRECONDITIONER_CONFIG])
],
state=self.state,
block_info_list=state_lists[
DISTRIBUTOR
].local_block_info_list,
beta2=group[BETAS][1],
epsilon=group[EPSILON],
use_bias_correction=group[USE_BIAS_CORRECTION],
)
)
case SignDescentPreconditionerConfig():
preconditioner_list_cls = partial(SignDescentPreconditionerList)
case SpectralDescentPreconditionerConfig():
assert (
group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
preconditioner_list_cls = SpectralDescentPreconditionerList
case _:
raise NotImplementedError(
f"{group[PRECONDITIONER_CONFIG]=} not supported!"
)

state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
preconditioner_config=group[PRECONDITIONER_CONFIG],
assert (
group[PRECONDITIONER_CONFIG] is not None
), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
self._preconditioner_config_to_list_cls(
state_lists=state_lists,
group=group,
preconditioner_config=group[PRECONDITIONER_CONFIG],
)
)

@torch.no_grad()
def _instantiate_grafting(self) -> None:
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
match group[GRAFTING_CONFIG]:
case None:
state_lists[GRAFTING_PRECONDITIONER_LIST] = None
case SGDGraftingConfig():
state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
)
case (
AdaGradGraftingConfig()
| RMSpropGraftingConfig()
| AdamGraftingConfig()
):
state_lists[GRAFTING_PRECONDITIONER_LIST] = (
AdagradPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
state=self.state,
block_info_list=state_lists[
DISTRIBUTOR
].local_block_info_list,
beta2=(
1.0
if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig
else group[GRAFTING_CONFIG].beta2
),
epsilon=group[GRAFTING_CONFIG].epsilon,
use_bias_correction=type(group[GRAFTING_CONFIG])
is AdamGraftingConfig,
)
)
case _:
raise NotImplementedError(
f"{group[GRAFTING_CONFIG]=} not supported!"
)
state_lists[GRAFTING_PRECONDITIONER_LIST] = (
self._preconditioner_config_to_list_cls(
state_lists=state_lists,
group=group,
preconditioner_config=group[GRAFTING_CONFIG],
)
)

@torch.no_grad()
def _instantiate_steps(self) -> None:
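The refactor above folds the former `_instantiate_grafting` match statement and the Shampoo-specific dispatch into the single `_preconditioner_config_to_list_cls` helper, so both the grafting slot and the main preconditioner slot are resolved by one match over the `PreconditionerConfig` hierarchy. The toy sketch below mirrors that dispatch shape with stand-in class names (not the library's real signatures), including the detail that AdaGrad is handled as the beta2 = 1.0, no-bias-correction special case of the Adam/RMSprop branch.

```python
from dataclasses import dataclass


# Simplified stand-ins for the real config classes; names are illustrative.
@dataclass
class SGDConfig: ...


@dataclass
class AdaGradConfig:
    epsilon: float = 1e-10


@dataclass
class AdamConfig:
    beta2: float = 0.999
    epsilon: float = 1e-8


def config_to_list(config: object | None) -> str | None:
    # One match statement maps every config type (or None) to the object that
    # would be stored in the per-group state lists.
    match config:
        case None:
            return None
        case SGDConfig():
            return "SGDPreconditionerList"
        case AdaGradConfig() | AdamConfig():
            # AdaGrad is the Adam/RMSprop branch with beta2 fixed to 1.0.
            beta2 = 1.0 if type(config) is AdaGradConfig else config.beta2
            return f"AdagradPreconditionerList(beta2={beta2})"
        case _:
            raise NotImplementedError(f"{config=} not supported!")


assert config_to_list(None) is None
assert config_to_list(AdaGradConfig()) == "AdagradPreconditionerList(beta2=1.0)"
```

The remaining files in this diff are integration tests whose only change is swapping `AdaGradGraftingConfig` for `AdaGradPreconditionerConfig` in the optimizer factories.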
@@ -27,7 +27,7 @@
from distributed_shampoo.distributor.shampoo_block_info import DTensorBlockInfo
from distributed_shampoo.distributor.shampoo_ddp_distributor import DDPDistributor
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdaGradPreconditionerConfig,
DDPDistributedConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
@@ -103,7 +103,7 @@ def _shampoo_optim_factory(
precondition_frequency=1,
start_preconditioning_step=2,
use_decoupled_weight_decay=True,
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
distributed_config=distributed_config,
preconditioner_config=preconditioner_config,
)
@@ -27,7 +27,7 @@
)
from distributed_shampoo.preconditioner.shampoo_preconditioner_list import SHAMPOO
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdaGradPreconditionerConfig,
DefaultSingleDeviceDistributedConfig,
FSDPDistributedConfig,
HSDPDistributedConfig,
@@ -124,7 +124,7 @@ def _shampoo_optim_factory(
precondition_frequency=1,
start_preconditioning_step=2,
use_decoupled_weight_decay=True,
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
distributed_config=distributed_config,
)

@@ -24,7 +24,7 @@
FullyShardDistributor,
)
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdaGradPreconditionerConfig,
DefaultSingleDeviceDistributedConfig,
FullyShardDistributedConfig,
HybridShardDistributedConfig,
@@ -128,7 +128,7 @@ def _shampoo_optim_factory(
precondition_frequency=1,
start_preconditioning_step=2,
use_decoupled_weight_decay=True,
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
distributed_config=distributed_config,
)

@@ -24,7 +24,7 @@
FullyShardLosslessDistributor,
)
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdaGradPreconditionerConfig,
DefaultSingleDeviceDistributedConfig,
FSDPParamAssignmentStrategy,
FullyShardDistributedConfig,
@@ -120,7 +120,7 @@ def _shampoo_optim_factory(
precondition_frequency=1,
start_preconditioning_step=2,
use_decoupled_weight_decay=True,
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
distributed_config=distributed_config,
)

@@ -29,7 +29,7 @@
from distributed_shampoo.distributor.shampoo_hsdp_distributor import HSDPDistributor
from distributed_shampoo.preconditioner.shampoo_preconditioner_list import SHAMPOO
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdaGradPreconditionerConfig,
DefaultSingleDeviceDistributedConfig,
HSDPDistributedConfig,
SingleDeviceDistributedConfig,
@@ -124,7 +124,7 @@ def _shampoo_optim_factory(
precondition_frequency=1,
start_preconditioning_step=2,
use_decoupled_weight_decay=True,
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
distributed_config=distributed_config,
)

@@ -29,7 +29,7 @@
)
from distributed_shampoo.preconditioner.shampoo_preconditioner_list import SHAMPOO
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdaGradPreconditionerConfig,
DDPDistributedConfig,
DefaultSingleDeviceDistributedConfig,
FullyShardDistributedConfig,
@@ -146,7 +146,7 @@ def _shampoo_optim_factory(
precondition_frequency=1,
start_preconditioning_step=start_preconditioning_step,
use_decoupled_weight_decay=True,
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
grafting_config=AdaGradPreconditionerConfig(epsilon=1e-8),
distributed_config=distributed_config,
)
