From d07ad6daf61d5962986c1de57a4168b0a4aeddcf Mon Sep 17 00:00:00 2001
From: Vishal Pandey <68536727+Vishal-sys-code@users.noreply.github.com>
Date: Fri, 8 Aug 2025 02:08:40 +0530
Subject: [PATCH] feat: Implement Muon Optimizer Grafting

Add MuonGraftingConfig as a subclass of SGDGraftingConfig and accept it in
DistributedShampoo._instantiate_grafting, where it is instantiated with the
same SGDPreconditionerList that SGDGraftingConfig uses.

Add test_muon_grafting, which runs DistributedShampoo with
MuonGraftingConfig and start_preconditioning_step=math.inf and compares it
against the SGD control optimizer, with and without Nesterov momentum and
weight decay.

---
 distributed_shampoo/distributed_shampoo.py    |  3 +-
 .../gpu_tests/shampoo_grafting_test.py        | 51 +++++++++++++++++++
 distributed_shampoo/shampoo_types.py          |  9 ++++
 3 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py
index 49b7a4b..f71aaa7 100644
--- a/distributed_shampoo/distributed_shampoo.py
+++ b/distributed_shampoo/distributed_shampoo.py
@@ -79,6 +79,7 @@
     MAX_PRECONDITIONER_DIM,
     MOMENTUM,
     MOMENTUM_LIST,
+    MuonGraftingConfig,
     PARAMS,
     PRECONDITION_FREQUENCY,
     PRECONDITIONER_CONFIG,
@@ -588,7 +589,7 @@ def _instantiate_grafting(self) -> None:
             match group[GRAFTING_CONFIG]:
                 case None:
                     state_lists[GRAFTING_PRECONDITIONER_LIST] = None
-                case SGDGraftingConfig():
+                case SGDGraftingConfig() | MuonGraftingConfig():
                     state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
                         block_list=state_lists[DISTRIBUTOR].local_blocked_params,
                     )
diff --git a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py
index 702a472..eff08f5 100644
--- a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py
+++ b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py
@@ -21,6 +21,7 @@
     AdamGraftingConfig,
     DefaultShampooConfig,
     EigendecomposedShampooPreconditionerConfig,
+    MuonGraftingConfig,
     PreconditionerConfig,
     RMSpropGraftingConfig,
     SGDGraftingConfig,
@@ -287,6 +288,56 @@ def test_sgd_grafting(
             device=device,
         )
 
+    @parametrize(
+        "preconditioner_config",
+        (
+            DefaultShampooConfig,
+            EigendecomposedShampooPreconditionerConfig(),
+        ),
+    )
+    @parametrize("device", available_devices)
+    @parametrize("use_nesterov", (True, False))
+    @parametrize("weight_decay", (0.0, 0.3))
+    def test_muon_grafting(
+        self,
+        weight_decay: float,
+        use_nesterov: bool,
+        device: torch.device,
+        preconditioner_config: PreconditionerConfig,
+    ) -> None:
+        optim_factory = partial(
+            DistributedShampooGraftingTest._optim_factory,
+            lr=0.1,
+            momentum=0.9,
+            weight_decay=weight_decay,
+        )
+        experimental_optim_factory = partial(
+            optim_factory,
+            optim_cls=DistributedShampoo,
+            betas=(0.0, 0.9),
+            epsilon=1e-10,
+            max_preconditioner_dim=10,
+            precondition_frequency=1,
+            start_preconditioning_step=math.inf,
+            use_nesterov=use_nesterov,
+            use_decoupled_weight_decay=False,
+            grafting_config=MuonGraftingConfig(),
+            preconditioner_config=preconditioner_config,
+        )
+
+        compare_two_optimizers_on_weight_and_loss(
+            control_optim_factory=partial(
+                optim_factory,
+                optim_cls=SGD,
+                nesterov=use_nesterov,
+            ),
+            experimental_optim_factory=experimental_optim_factory,
+            # Setting model_linear_layers_dims to (10, 10) to ensure a simple model structure,
+            # as SGD can be sensitive to the choice of model architecture.
+            model_linear_layers_dims=(10, 10),
+            device=device,
+        )
+
         compare_optimizer_on_cpu_and_device(
             optim_factory=experimental_optim_factory,
             # Using the same model_linear_layers_dims for consistency in testing across devices.
diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py
index 9d51b1f..e39af36 100644
--- a/distributed_shampoo/shampoo_types.py
+++ b/distributed_shampoo/shampoo_types.py
@@ -617,6 +617,15 @@ class SGDGraftingConfig(GraftingConfig):
     """Configuration for grafting from SGD."""
 
 
+@dataclass
+class MuonGraftingConfig(SGDGraftingConfig):
+    """Configuration for grafting from Muon, currently handled exactly like SGD grafting.
+
+    Note:
+        The momentum parameter is set by momentum in DistributedShampoo's hyperparameters.
+    """
+
+
 @dataclass(kw_only=True)
 class AdaGradGraftingConfig(GraftingConfig):
     """Configuration for grafting from AdaGrad.