Commit f30d167
WOQ: fix phi3 issue with latest DeepSpeed and TP=6 (#3655)
1 parent 002f88d commit f30d167
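
This commit replaces the substring-based layer-name matching in `shard_low_precision_checkpoint` with exact matching on dot-separated key components (the new `_key_belongs_to` helper) and adds dedicated sharding paths for the fused `qkv_proj` and `gate_up_proj` weights used by phi3 checkpoints. The sketch below is not part of the commit; it illustrates one plausible reading of why the old check can misroute fused layers. The list contents and the key name are assumed for illustration (the diff only shows part of `mha_layers_split_by_N`).

# Illustrative sketch, not from the commit: substring vs. exact-component matching.
mha_layers_split_by_N = ["q_proj", "k_proj", "v_proj"]  # assumed list contents

def _key_belongs_to(key, layer_group):
    # Same idea as the helper added by this commit: compare whole
    # dot-separated components of the key, not arbitrary substrings.
    key_split = key.split(".")
    return any(layer in key_split for layer in layer_group)

fused_key = "model.layers.0.self_attn.qkv_proj.qweight"  # assumed phi3-style key

# Old check: "v_proj" is a substring of "qkv_proj", so the fused tensor would be
# routed into the per-head q/k/v split path.
print(any(s in fused_key for s in mha_layers_split_by_N))  # True
# New check: "qkv_proj" is a single key component, so it no longer matches here
# and falls through to the dedicated qkv_proj branch instead.
print(_key_belongs_to(fused_key, mha_layers_split_by_N))  # False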

File tree

1 file changed: +195 −8 lines

intel_extension_for_pytorch/llm/utils.py

Lines changed: 195 additions & 8 deletions
@@ -912,8 +912,9 @@ def shard_low_precision_checkpoint(
     """
     assert tp_grain_size % 8 == 0, "tp_grain_size must be a multiple of 8"
     num_heads = model_config["num_attention_heads"]
+    num_kv_heads = num_heads
     if "num_key_value_heads" in model_config:
-        num_heads = model_config["num_key_value_heads"]
+        num_kv_heads = model_config["num_key_value_heads"]
     local_rank = rank
 
     mha_layers_split_by_N = [
@@ -923,6 +924,9 @@ def shard_low_precision_checkpoint(
         "q_b_proj",
         "kv_b_proj",
     ]
+    qkv_proj_layers = [
+        "qkv_proj",
+    ]
     # mlp is split with grain size = tp_grain_size
     mlp_layers_split_by_N = [
         "gate_proj",
@@ -933,6 +937,9 @@ def shard_low_precision_checkpoint(
         "w1",
         "w3",
     ]
+    gate_up_proj_layers = [
+        "gate_up_proj",
+    ]
     mha_layers_split_by_K = [
         "o_proj",
         "out_proj",
@@ -947,20 +954,28 @@ def shard_low_precision_checkpoint(
         "w2",
     ]
     lm_head_layers = ["lm_head"]  # split by K but not quantized
+
+    def _key_belongs_to(key, layer_group):
+        key_split = key.split(".")
+        for layer in layer_group:
+            if layer in key_split:
+                return True
+        return False
+
     low_precision_checkpoint_dict = low_precision_checkpoint.copy()
     head_range = [0]
-    head_per_rank = num_heads // world_size
+    head_per_rank = num_kv_heads // world_size
     for i in range(0, world_size):
         head_this_rank = head_per_rank
-        if i < num_heads % world_size:
+        if i < num_kv_heads % world_size:
             head_this_rank += 1
         head_range.append(head_range[-1] + head_this_rank)
     for key in low_precision_checkpoint.keys():
         q_head_start = head_range[rank]
         q_head_end = q_head_start + (head_range[rank + 1] - head_range[rank])
         if "bias" in key:
             continue
-        if any(substring in key for substring in mha_layers_split_by_N):
+        if _key_belongs_to(key, mha_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
@@ -1041,7 +1056,91 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_N):
+        elif _key_belongs_to(key, qkv_proj_layers):
+            # need to split q, k and v proj then shard them separately
+            # finally concat them together
+            # mha layer split by N
+            data = low_precision_checkpoint_dict[key]
+            hidden_size = model_config["hidden_size"]
+            head_dim = hidden_size // num_heads
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                N_pack_factor = 1 if "scales" in key else 8
+                N = data.shape[-1] * N_pack_factor
+                q_pos = N - 2 * num_kv_heads * head_dim
+                k_pos = q_pos + num_kv_heads * head_dim
+                v_pos = k_pos + num_kv_heads * head_dim
+                q_pos //= N_pack_factor
+                k_pos //= N_pack_factor
+                v_pos //= N_pack_factor
+                data_list = [
+                    data[:, :q_pos],
+                    data[:, q_pos:k_pos],
+                    data[:, k_pos:v_pos],
+                ]
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                        q_head_start = local_rank
+                        q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = []
+                if "g_idx" not in key:
+                    N_pack_factor = 8 if "qzeros" in key else 1
+                    N = data.shape[-1] * N_pack_factor
+                    q_pos = N - 2 * num_kv_heads * head_dim
+                    k_pos = q_pos + num_kv_heads * head_dim
+                    v_pos = k_pos + num_kv_heads * head_dim
+                    q_pos //= N_pack_factor
+                    k_pos //= N_pack_factor
+                    v_pos //= N_pack_factor
+                    data_list = [
+                        data[:, :q_pos],
+                        data[:, q_pos:k_pos],
+                        data[:, k_pos:v_pos],
+                    ]
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                        q_head_start = local_rank
+                        q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mlp_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
@@ -1178,7 +1277,95 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mha_layers_split_by_K):
+        elif _key_belongs_to(key, gate_up_proj_layers):
+            # need to split gate and up proj then shard them separately
+            # finally concat them together
+            # mlp layer split by N
+            data = low_precision_checkpoint_dict[key]
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if "scales" in key:
+                        assert (
+                            data.shape[1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // tp_grain_size
+                        dim = tp_grain_size
+                    else:
+                        assert (
+                            data.shape[1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if "qzeros" in key:
+                        assert (
+                            data.shape[-1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    elif "g_idx" not in key:  # qweight, scales
+                        assert (
+                            data.shape[-1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // tp_grain_size
+                        dim = tp_grain_size
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mha_layers_split_by_K):
+            if "bias" in key:
+                continue
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
@@ -1269,7 +1456,7 @@ def shard_low_precision_checkpoint(
                 ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_K):
+        elif _key_belongs_to(key, mlp_layers_split_by_K):
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
@@ -1422,7 +1609,7 @@ def shard_low_precision_checkpoint(
                 ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in lm_head_layers):
+        elif _key_belongs_to(key, lm_head_layers):
             # lm_head: [N, K] (not quantized)
             # Same for all quantization methods
             data = low_precision_checkpoint_dict[key]
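
For reference, a standalone sketch (not part of the commit) of the sharding arithmetic that the new `qkv_proj` and `gate_up_proj` branches perform. The model dimensions below are assumed phi-3-mini-like values chosen only for illustration (32 query heads, 32 KV heads, hidden size 3072, intermediate size 8192 per chunk), together with an assumed `tp_grain_size` of 64; the commit itself does not fix these numbers.

# Standalone sketch of the arithmetic behind the new branches (assumed dimensions).
world_size = 6                      # TP degree from the commit title
num_heads = 32                      # assumed query heads
num_kv_heads = 32                   # assumed KV heads
hidden_size = 3072                  # assumed hidden size
head_dim = hidden_size // num_heads
tp_grain_size = 64                  # assumed grain size (must be a multiple of 8)

# 1) Uneven head split across ranks (mirrors the head_range loop in the diff).
head_range = [0]
head_per_rank = num_kv_heads // world_size
for i in range(world_size):
    head_this_rank = head_per_rank + (1 if i < num_kv_heads % world_size else 0)
    head_range.append(head_range[-1] + head_this_rank)
print(head_range)  # [0, 6, 12, 17, 22, 27, 32]: ranks 0-1 hold 6 heads, ranks 2-5 hold 5

# 2) Column offsets of q, k and v inside a fused qkv_proj of width N
#    (same formulas as the AWQ/GPTQ branches, before dividing by the pack factor).
N = hidden_size + 2 * num_kv_heads * head_dim   # fused output width
q_pos = N - 2 * num_kv_heads * head_dim         # end of the q columns
k_pos = q_pos + num_kv_heads * head_dim         # end of the k columns
v_pos = k_pos + num_kv_heads * head_dim         # end of the v columns
print(q_pos, k_pos, v_pos)  # 3072 6144 9216

# 3) Grain-based split of one gate_up_proj chunk of width N_mlp across ranks,
#    distributing the remainder grains to the lowest ranks.
N_mlp = 8192                                    # assumed intermediate size per chunk
grains = N_mlp // tp_grain_size
grains_per_rank, grains_rem = divmod(grains, world_size)
for local_rank in range(world_size):
    grains_start = grains_per_rank * local_rank + min(local_rank, grains_rem)
    grains_end = grains_start + grains_per_rank + (1 if local_rank < grains_rem else 0)
    print(local_rank, grains_start * tp_grain_size, grains_end * tp_grain_size)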
