Commit aabcb14

spcyppt authored and meta-codesync[bot] committed
Add CPU support for rowwise adagrad with counter (#4986)
Summary:
Pull Request resolved: #4986
X-link: https://github.com/facebookresearch/FBGEMM/pull/1999

Add CPU support for rowwise adagrad with counter. This is to unblock MAI v3 on MTIA.

Reviewed By: nautsimon, q10

Differential Revision: D81998586

fbshipit-source-id: eb9ddb4003443981a31a043b0b8257a381cf06c8
1 parent e22df7f · commit aabcb14
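For orientation, here is a minimal sketch of the path this commit unblocks: requesting rowwise AdaGrad with counter-based weight decay from a CPU-only TBE module. The module and enum names below (SplitTableBatchedEmbeddingBagsCodegen, CounterBasedRegularizationDefinition, EmbeddingLocation.HOST, ComputeDevice.CPU) come from the existing fbgemm_gpu training frontend rather than from this diff, and the exact constructor arguments are an assumption that may differ by version.

# Sketch only: assumes the standard fbgemm_gpu TBE training API; argument names
# and defaults here are not taken from this commit and may differ by version.
import torch

from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    CounterBasedRegularizationDefinition,
    CounterWeightDecayMode,
    SplitTableBatchedEmbeddingBagsCodegen,
    WeightDecayMode,
)

# One host-resident table (1000 rows x 64 dims), trained entirely on CPU with
# rowwise AdaGrad and counter-based (decoupled) weight decay.
emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(1000, 64, EmbeddingLocation.HOST, ComputeDevice.CPU)],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.01,
    weight_decay=1.0e-4,
    weight_decay_mode=WeightDecayMode.COUNTER,
    counter_based_regularization=CounterBasedRegularizationDefinition(
        counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
        counter_halflife=100,
    ),
)

# A single lookup + backward; the fused optimizer update (including the row
# counters) happens inside the module's backward.
indices = torch.tensor([1, 3, 5], dtype=torch.int64)
offsets = torch.tensor([0, 2, 3], dtype=torch.int64)  # batch size 2, one table
emb(indices=indices, offsets=offsets).sum().backward()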

File tree

3 files changed (+77, -5 lines):
  fbgemm_gpu/cmake/tbe_sources.py
  fbgemm_gpu/codegen/genscript/optimizers.py
  fbgemm_gpu/test/tbe/training/backward_optimizers_test.py


fbgemm_gpu/cmake/tbe_sources.py
Lines changed: 1 addition & 1 deletion

@@ -12,6 +12,7 @@
     "adagrad",
     "rowwise_adagrad",
     "sgd",
+    "rowwise_adagrad_with_counter",
 ]

 # To be populated in the subsequent diffs
@@ -24,7 +25,6 @@
     "partial_rowwise_adam",
     "partial_rowwise_lamb",
     "none",
-    "rowwise_adagrad_with_counter",
 ]

 DEPRECATED_OPTIMIZERS = [

fbgemm_gpu/codegen/genscript/optimizers.py
Lines changed: 1 addition & 1 deletion

@@ -652,7 +652,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
         "split_weight_update": split_weight_update,
         "split_post_update": "",
         "split_weight_update_cpu": split_weight_update_cpu,
-        "has_cpu_support": False,
+        "has_cpu_support": True,
         "has_gpu_support": True,
         "has_vbe_support": True,
         "has_global_weight_decay_support": False,

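For readers unfamiliar with this optimizer, below is a rough, self-contained PyTorch sketch of the kind of update that split_weight_update / split_weight_update_cpu generate: rowwise AdaGrad (a single accumulator per embedding row) combined with a counter-scaled, decoupled weight decay. The counter scaling shown is a simplified stand-in, not the generated kernel, which also handles adjustment_iter, tail_id_threshold, and the other CounterWeightDecayMode variants.

import torch


def rowwise_adagrad_with_counter_step(
    weights: torch.Tensor,   # (rows, dim) embedding rows touched this iteration
    grads: torch.Tensor,     # (rows, dim) gradients for those rows
    momentum: torch.Tensor,  # (rows,) per-row sum of mean squared gradients
    counters: torch.Tensor,  # (rows,) float update counts per row
    lr: float = 0.01,
    eps: float = 1.0e-8,
    weight_decay: float = 1.0e-4,
    counter_halflife: int = 100,
) -> None:
    # Track how often each row has been updated.
    counters += 1.0
    # Rowwise AdaGrad: one accumulator per row, built from the mean squared grad.
    momentum += grads.pow(2).mean(dim=1)
    multiplier = lr / (momentum.sqrt() + eps)  # (rows,)
    # Counter-scaled, decoupled weight decay: frequently updated rows receive a
    # smaller effective decay. This particular scaling is an illustrative
    # assumption, not the exact formula used by the generated kernel.
    freq = torch.clamp(counter_halflife / counters, max=1.0)  # (rows,)
    weights *= 1.0 - lr * weight_decay * freq.unsqueeze(1)
    weights -= multiplier.unsqueeze(1) * grads

In the actual TBE module these buffers live in the split optimizer state and the step runs inside the fused backward; the sketch only illustrates the per-row arithmetic.
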
fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
Lines changed: 75 additions & 3 deletions

@@ -134,8 +134,7 @@ def execute_backward_optimizers_( # noqa C901
                 ]
             )
             and (
-                use_cpu
-                or optimizer != OptimType.EXACT_ROWWISE_ADAGRAD
+                optimizer != OptimType.EXACT_ROWWISE_ADAGRAD
                 or weight_decay_mode
                 not in [
                     WeightDecayMode.COUNTER,
@@ -1205,7 +1204,7 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901
         deadline=None,
         suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
     )
-    @unittest.skipIf(*gpu_unavailable)
+    # @unittest.skipIf(*gpu_unavailable)
     def test_backward_optimizers_adagrad( # noqa C901
         self,
         T: int,
@@ -1247,6 +1246,79 @@ def test_backward_optimizers_adagrad( # noqa C901
             counter_halflife=counter_halflife,
         )

+    @given(
+        T=st.integers(min_value=1, max_value=5),
+        D=st.integers(min_value=2, max_value=256),
+        B=st.integers(min_value=1, max_value=128),
+        log_E=st.integers(min_value=3, max_value=5),
+        L=st.integers(min_value=2, max_value=20),
+        weighted=st.booleans(),
+        mixed=st.booleans(),
+        mixed_B=st.booleans(),
+        long_segments=st.booleans(),
+        pooling_mode=st.sampled_from(
+            [
+                PoolingMode.SUM,
+                PoolingMode.MEAN,
+                PoolingMode.NONE,
+            ]
+        ),
+        weight_decay_mode=st.sampled_from(
+            [
+                WeightDecayMode.COUNTER,
+                WeightDecayMode.COWCLIP,
+            ]
+        ),
+        counter_weight_decay_mode=st.sampled_from(
+            [
+                CounterWeightDecayMode.NONE,
+                CounterWeightDecayMode.L2,
+                CounterWeightDecayMode.DECOUPLE,
+                CounterWeightDecayMode.ADAGRADW,
+            ]
+        ),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    # @unittest.skipIf(*gpu_unavailable)
+    def test_backward_optimizers_adagrad_with_counter_cpu( # noqa C901
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        mixed: bool,
+        mixed_B: bool,
+        long_segments: bool,
+        pooling_mode: PoolingMode,
+        weight_decay_mode: WeightDecayMode,
+        counter_weight_decay_mode: CounterWeightDecayMode,
+    ) -> None:
+        if pooling_mode == PoolingMode.NONE:
+            mixed_B = False
+        self.execute_backward_optimizers_(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            mixed,
+            mixed_B,
+            OptimType.EXACT_ROWWISE_ADAGRAD,
+            long_segments,
+            pooling_mode,
+            True, # use_cpu
+            weight_decay_mode,
+            counter_weight_decay_mode=counter_weight_decay_mode,
+        )
+
     @given(
         T=st.integers(min_value=1, max_value=5),
         D=st.sampled_from([16, 32, 40, 48, 64]),
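
To exercise only the new CPU test locally, something along the following lines should work when run from the repository root; it assumes pytest plus the usual fbgemm_gpu test dependencies (torch, hypothesis) are installed, none of which is specified by this commit.

# Hypothetical local runner; selects the new test by name substring.
# Not part of this commit.
import pytest

pytest.main(
    [
        "fbgemm_gpu/test/tbe/training/backward_optimizers_test.py",
        "-k",
        "adagrad_with_counter_cpu",
    ]
)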
