From 513b63a4a97839a2767f91adf03c05bc33a2a6a3 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Tue, 7 Oct 2025 21:03:00 -0700 Subject: [PATCH] Add CPU support for rowwise adagrad with counter (#4986) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1999 Add CPU support for rowwise adagrad with counter. This is to unblock MAI v3 on MTIA. Differential Revision: D81998586 --- fbgemm_gpu/cmake/tbe_sources.py | 2 +- fbgemm_gpu/codegen/genscript/optimizers.py | 2 +- .../tbe/training/backward_optimizers_test.py | 78 ++++++++++++++++++- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 31200b6190..82092cc173 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -12,6 +12,7 @@ "adagrad", "rowwise_adagrad", "sgd", + "rowwise_adagrad_with_counter", ] # To be populated in the subsequent diffs @@ -24,7 +25,6 @@ "partial_rowwise_adam", "partial_rowwise_lamb", "none", - "rowwise_adagrad_with_counter", ] DEPRECATED_OPTIMIZERS = [ diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 7141d78e10..c61e6843f9 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -652,7 +652,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: "split_weight_update": split_weight_update, "split_post_update": "", "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, + "has_cpu_support": True, "has_gpu_support": True, "has_vbe_support": True, "has_global_weight_decay_support": False, diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index a6bc22770b..4306ea2d1c 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -134,8 +134,7 @@ def execute_backward_optimizers_( # noqa C901 ] ) and ( - use_cpu - or optimizer != OptimType.EXACT_ROWWISE_ADAGRAD + optimizer != OptimType.EXACT_ROWWISE_ADAGRAD or weight_decay_mode not in [ WeightDecayMode.COUNTER, @@ -1205,7 +1204,7 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901 deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], ) - @unittest.skipIf(*gpu_unavailable) + # @unittest.skipIf(*gpu_unavailable) def test_backward_optimizers_adagrad( # noqa C901 self, T: int, @@ -1247,6 +1246,79 @@ def test_backward_optimizers_adagrad( # noqa C901 counter_halflife=counter_halflife, ) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=2, max_value=20), + weighted=st.booleans(), + mixed=st.booleans(), + mixed_B=st.booleans(), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ), + weight_decay_mode=st.sampled_from( + [ + WeightDecayMode.COUNTER, + WeightDecayMode.COWCLIP, + ] + ), + counter_weight_decay_mode=st.sampled_from( + [ + CounterWeightDecayMode.NONE, + CounterWeightDecayMode.L2, + CounterWeightDecayMode.DECOUPLE, + CounterWeightDecayMode.ADAGRADW, + ] + ), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + # @unittest.skipIf(*gpu_unavailable) + def test_backward_optimizers_adagrad_with_counter_cpu( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + mixed: bool, + mixed_B: bool, + long_segments: bool, + pooling_mode: PoolingMode, + weight_decay_mode: WeightDecayMode, + counter_weight_decay_mode: CounterWeightDecayMode, + ) -> None: + if pooling_mode == PoolingMode.NONE: + mixed_B = False + self.execute_backward_optimizers_( + T, + D, + B, + log_E, + L, + weighted, + mixed, + mixed_B, + OptimType.EXACT_ROWWISE_ADAGRAD, + long_segments, + pooling_mode, + True, # use_cpu + weight_decay_mode, + counter_weight_decay_mode=counter_weight_decay_mode, + ) + @given( T=st.integers(min_value=1, max_value=5), D=st.sampled_from([16, 32, 40, 48, 64]),