@@ -912,8 +912,9 @@ def shard_low_precision_checkpoint(
     """
     assert tp_grain_size % 8 == 0, "tp_grain_size must be a multiple of 8"
     num_heads = model_config["num_attention_heads"]
+    num_kv_heads = num_heads
     if "num_key_value_heads" in model_config:
-        num_heads = model_config["num_key_value_heads"]
+        num_kv_heads = model_config["num_key_value_heads"]
     local_rank = rank

     mha_layers_split_by_N = [
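For grouped-query-attention (GQA) checkpoints this hunk fixes a real bug: the old code overwrote `num_heads` with `num_key_value_heads`, so any later use of the query-head count, such as the `head_dim = hidden_size // num_heads` computation in the fused-QKV path added below, would silently be based on the KV-head count. A minimal sketch of the difference, using a hypothetical GQA config for illustration only:

```python
# Hypothetical GQA config, for illustration only.
model_config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
}

num_heads = model_config["num_attention_heads"]                    # 32
num_kv_heads = model_config.get("num_key_value_heads", num_heads)  # 8

head_dim = model_config["hidden_size"] // num_heads  # 4096 // 32 = 128 (correct)
# The pre-fix code effectively computed:
bad_head_dim = model_config["hidden_size"] // num_kv_heads  # 4096 // 8 = 512 (wrong)
```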
@@ -923,6 +924,9 @@ def shard_low_precision_checkpoint(
         "q_b_proj",
         "kv_b_proj",
     ]
+    qkv_proj_layers = [
+        "qkv_proj",
+    ]
     # mlp is split with grain size = tp_grain_size
     mlp_layers_split_by_N = [
         "gate_proj",
@@ -933,6 +937,9 @@ def shard_low_precision_checkpoint(
         "w1",
         "w3",
     ]
+    gate_up_proj_layers = [
+        "gate_up_proj",
+    ]
     mha_layers_split_by_K = [
         "o_proj",
         "out_proj",
@@ -947,20 +954,28 @@ def shard_low_precision_checkpoint(
         "w2",
     ]
     lm_head_layers = ["lm_head"]  # split by K but not quantized
+
+    def _key_belongs_to(key, layer_group):
+        key_split = key.split(".")
+        for layer in layer_group:
+            if layer in key_split:
+                return True
+        return False
+
     low_precision_checkpoint_dict = low_precision_checkpoint.copy()
     head_range = [0]
-    head_per_rank = num_heads // world_size
+    head_per_rank = num_kv_heads // world_size
     for i in range(0, world_size):
         head_this_rank = head_per_rank
-        if i < num_heads % world_size:
+        if i < num_kv_heads % world_size:
             head_this_rank += 1
         head_range.append(head_range[-1] + head_this_rank)
     for key in low_precision_checkpoint.keys():
         q_head_start = head_range[rank]
         q_head_end = q_head_start + (head_range[rank + 1] - head_range[rank])
         if "bias" in key:
             continue
-        if any(substring in key for substring in mha_layers_split_by_N):
+        if _key_belongs_to(key, mha_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
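Note why the matching helper accompanies the rename: `any(substring in key ...)` does plain substring matching, so a fused key such as `model.layers.0.self_attn.qkv_proj.qweight` (a hypothetical name following the usual HF convention) would also match an entry like `v_proj` (presumably among the elided `mha_layers_split_by_N` entries), since `v_proj` is a suffix of `qkv_proj`, and be sent down the per-head column-split path instead of the new fused path. `_key_belongs_to` matches whole dot-separated path components, so only an exact module name hits:

```python
def _key_belongs_to(key, layer_group):
    # Match against whole dot-separated components, not raw substrings.
    key_split = key.split(".")
    for layer in layer_group:
        if layer in key_split:
            return True
    return False

key = "model.layers.0.self_attn.qkv_proj.qweight"  # hypothetical fused-QKV key

print(any(s in key for s in ["v_proj"]))  # True  -- substring match misfires
print(_key_belongs_to(key, ["v_proj"]))   # False -- component match does not
print(_key_belongs_to(key, ["qkv_proj"])) # True  -- fused path is chosen
```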
@@ -1041,7 +1056,91 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_N):
+        elif _key_belongs_to(key, qkv_proj_layers):
+            # need to split q, k and v proj then shard them separately
+            # finally concat them together
+            # mha layer split by N
+            data = low_precision_checkpoint_dict[key]
+            hidden_size = model_config["hidden_size"]
+            head_dim = hidden_size // num_heads
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                N_pack_factor = 1 if "scales" in key else 8
+                N = data.shape[-1] * N_pack_factor
+                q_pos = N - 2 * num_kv_heads * head_dim
+                k_pos = q_pos + num_kv_heads * head_dim
+                v_pos = k_pos + num_kv_heads * head_dim
+                q_pos //= N_pack_factor
+                k_pos //= N_pack_factor
+                v_pos //= N_pack_factor
+                data_list = [
+                    data[:, :q_pos],
+                    data[:, q_pos:k_pos],
+                    data[:, k_pos:v_pos],
+                ]
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                        q_head_start = local_rank
+                        q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = []
+                if "g_idx" not in key:
+                    N_pack_factor = 8 if "qzeros" in key else 1
+                    N = data.shape[-1] * N_pack_factor
+                    q_pos = N - 2 * num_kv_heads * head_dim
+                    k_pos = q_pos + num_kv_heads * head_dim
+                    v_pos = k_pos + num_kv_heads * head_dim
+                    q_pos //= N_pack_factor
+                    k_pos //= N_pack_factor
+                    v_pos //= N_pack_factor
+                    data_list = [
+                        data[:, :q_pos],
+                        data[:, q_pos:k_pos],
+                        data[:, k_pos:v_pos],
+                    ]
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                        q_head_start = local_rank
+                        q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mlp_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
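The offset arithmetic in this new branch assumes the fused projection is laid out as `[Q | K | V]` along the output (N) dimension, with the K and V blocks each `num_kv_heads * head_dim` columns wide and the Q block taking whatever remains at the front. For packed tensors (an AWQ `qweight` stores 8 int4 values per int32, so it has `N // 8` columns) the boundaries are computed in unpacked units first and then divided by the pack factor. A worked sketch with hypothetical sizes:

```python
# Hypothetical GQA sizes, for illustration only.
hidden_size, num_heads, num_kv_heads = 4096, 32, 8
head_dim = hidden_size // num_heads  # 128

# Unpacked fused width: Q (32*128) + K (8*128) + V (8*128) = 6144 columns.
N = num_heads * head_dim + 2 * num_kv_heads * head_dim

q_pos = N - 2 * num_kv_heads * head_dim  # 4096: end of Q block
k_pos = q_pos + num_kv_heads * head_dim  # 5120: end of K block
v_pos = k_pos + num_kv_heads * head_dim  # 6144: end of V block

# For an AWQ qweight of shape [K, N // 8], the same boundaries in packed columns:
N_pack_factor = 8
print(q_pos // N_pack_factor, k_pos // N_pack_factor, v_pos // N_pack_factor)
# 512 640 768
```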
@@ -1178,7 +1277,95 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mha_layers_split_by_K):
+        elif _key_belongs_to(key, gate_up_proj_layers):
+            # need to split gate and up proj then shard them separately
+            # finally concat them together
+            # mlp layer split by N
+            data = low_precision_checkpoint_dict[key]
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if "scales" in key:
+                        assert (
+                            data.shape[1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // tp_grain_size
+                        dim = tp_grain_size
+                    else:
+                        assert (
+                            data.shape[1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if "qzeros" in key:
+                        assert (
+                            data.shape[-1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    elif "g_idx" not in key:  # qweight, scales
+                        assert (
+                            data.shape[-1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // tp_grain_size
+                        dim = tp_grain_size
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mha_layers_split_by_K):
+            if "bias" in key:
+                continue
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
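Both quantization branches here shard `gate_up_proj` in units of `tp_grain_size` columns ("grains") and hand any remainder grains to the lowest ranks, i.e. a balanced block partition: rank `r` gets `grains // world_size` grains plus one extra when `r < grains % world_size`. A standalone sketch of the same index math, assuming 10 grains over 4 ranks:

```python
def grain_range(grains, world_size, rank):
    # Balanced block partition: low ranks absorb the remainder.
    per_rank, rem = divmod(grains, world_size)
    start = per_rank * rank + min(rank, rem)
    end = start + per_rank + (1 if rank < rem else 0)
    return start, end

# 10 grains over 4 ranks -> shard sizes [3, 3, 2, 2]
print([grain_range(10, 4, r) for r in range(4)])
# [(0, 3), (3, 6), (6, 8), (8, 10)]
```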
@@ -1269,7 +1456,7 @@ def shard_low_precision_checkpoint(
                 ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_K):
+        elif _key_belongs_to(key, mlp_layers_split_by_K):
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
@@ -1422,7 +1609,7 @@ def shard_low_precision_checkpoint(
                 ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in lm_head_layers):
+        elif _key_belongs_to(key, lm_head_layers):
             # lm_head: [N, K] (not quantized)
             # Same for all quantization methods
             data = low_precision_checkpoint_dict[key]
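For reference, the `lm_head` branch that follows (unchanged here) slices the full-precision `[N, K]` weight along the input dimension K, the row-parallel layout, so each rank computes a partial matmul that is typically all-reduced afterwards. A minimal sketch of such a K-dim slice, assuming an even split (the real code may balance remainders differently):

```python
import torch

# Hypothetical unquantized lm_head weight of shape [N, K].
weight = torch.randn(32000, 4096)
world_size, rank = 4, 1

K = weight.shape[-1]
assert K % world_size == 0  # assuming an even split, for simplicity
k_per_rank = K // world_size
shard = weight[:, rank * k_per_rank : (rank + 1) * k_per_rank].contiguous()
print(shard.shape)  # torch.Size([32000, 1024])
```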