Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion distributed_shampoo/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
MAX_PRECONDITIONER_DIM,
MOMENTUM,
MOMENTUM_LIST,
MuonGraftingConfig,
PARAMS,
PRECONDITION_FREQUENCY,
PRECONDITIONER_CONFIG,
Expand Down Expand Up @@ -588,7 +589,7 @@ def _instantiate_grafting(self) -> None:
match group[GRAFTING_CONFIG]:
case None:
state_lists[GRAFTING_PRECONDITIONER_LIST] = None
case SGDGraftingConfig():
case SGDGraftingConfig() | MuonGraftingConfig():
state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
)
Expand Down
51 changes: 51 additions & 0 deletions distributed_shampoo/gpu_tests/shampoo_grafting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AdamGraftingConfig,
DefaultShampooConfig,
EigendecomposedShampooPreconditionerConfig,
MuonGraftingConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
Expand Down Expand Up @@ -287,6 +288,56 @@ def test_sgd_grafting(
device=device,
)

@parametrize(
"preconditioner_config",
(
DefaultShampooConfig,
EigendecomposedShampooPreconditionerConfig(),
),
)
@parametrize("device", available_devices)
@parametrize("use_nesterov", (True, False))
@parametrize("weight_decay", (0.0, 0.3))
def test_muon_grafting(
self,
weight_decay: float,
use_nesterov: bool,
device: torch.device,
preconditioner_config: PreconditionerConfig,
) -> None:
optim_factory = partial(
DistributedShampooGraftingTest._optim_factory,
lr=0.1,
momentum=0.9,
weight_decay=weight_decay,
)
experimental_optim_factory = partial(
optim_factory,
optim_cls=DistributedShampoo,
betas=(0.0, 0.9),
epsilon=1e-10,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=math.inf,
use_nesterov=use_nesterov,
use_decoupled_weight_decay=False,
grafting_config=MuonGraftingConfig(),
preconditioner_config=preconditioner_config,
)

compare_two_optimizers_on_weight_and_loss(
control_optim_factory=partial(
optim_factory,
optim_cls=SGD,
nesterov=use_nesterov,
),
experimental_optim_factory=experimental_optim_factory,
# Setting model_linear_layers_dims to (10, 10) to ensure a simple model structure,
# as SGD can be sensitive to the choice of model architecture.
model_linear_layers_dims=(10, 10),
device=device,
)

compare_optimizer_on_cpu_and_device(
optim_factory=experimental_optim_factory,
# Using the same model_linear_layers_dims for consistency in testing across devices.
Expand Down
9 changes: 9 additions & 0 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,15 @@ class SGDGraftingConfig(GraftingConfig):
"""Configuration for grafting from SGD."""


@dataclass
class MuonGraftingConfig(SGDGraftingConfig):
    """Configuration for grafting from Momentum-SGD (Muon).

    Declares no fields of its own; it is a marker subclass of
    SGDGraftingConfig, and grafting dispatch treats it the same way
    (the SGD grafting preconditioner is used for both).

    Note:
        The momentum coefficient is not configured here; it is taken from
        the ``momentum`` hyperparameter of DistributedShampoo.
    """


@dataclass(kw_only=True)
class AdaGradGraftingConfig(GraftingConfig):
"""Configuration for grafting from AdaGrad.
Expand Down