
Commit adc2f36

fix some bugs
1 parent 546e1cb commit adc2f36

File tree

3 files changed: +61 −23 lines changed


paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 25 additions & 16 deletions
@@ -33,8 +33,8 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.communication.reduce import ReduceOp
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -799,7 +799,7 @@ def __init__(self, config: DeepseekV2Config):
 
         for p in self.experts.parameters():
             setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
-
+            setattr(p, "is_moe_param", True)
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -851,6 +851,7 @@ def __init__(self, config: DeepseekV2Config):
 
         for p in self.experts.parameters():
             setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+            setattr(p, "is_moe_param", True)
 
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
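Note: the new is_moe_param flag sits alongside the existing color attribute on expert parameters. A minimal, hypothetical sketch of how such a flag could be consumed downstream to separate expert weights from dense weights (the helper split_moe_params below is illustrative and not part of this commit):

def split_moe_params(layer):
    # Partition a paddle.nn.Layer's parameters using the is_moe_param tag
    # that DeepseekV2MoE.__init__ now sets on every expert parameter.
    moe_params, dense_params = [], []
    for p in layer.parameters():
        if getattr(p, "is_moe_param", False):
            moe_params.append(p)
        else:
            dense_params.append(p)
    return moe_params, dense_params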
@@ -895,7 +896,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
         self.num_heads = config.num_attention_heads
         self.num_local_heads = self.num_heads
         if config.tensor_parallel_degree > 1:
-            assert self.num_heads % config.tensor_parallel_degree == 0, f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
+            assert (
+                self.num_heads % config.tensor_parallel_degree == 0
+            ), f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
             self.num_local_heads = self.num_heads // config.tensor_parallel_degree
 
         self.max_position_embeddings = config.max_position_embeddings
@@ -1067,7 +1070,12 @@ def forward(
 
         if self.sequence_parallel:
             target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim]
-            target_key_value_shape = [bsz, self.seq_length, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim]
+            target_key_value_shape = [
+                bsz,
+                self.seq_length,
+                self.num_local_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+            ]
         else:
             target_query_shape = [0, 0, self.num_heads, self.q_head_dim]
             target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim]
@@ -1153,7 +1161,6 @@ def forward(
         if attn_output.shape != ori_shape:
             attn_output = attn_output.reshape(ori_shape)
 
-
         if not output_attentions:
             attn_weights = None
 
@@ -1511,7 +1518,7 @@ def forward(
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
 
-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+        hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1))
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,
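Note: the fix above only swaps the concatenation order fed to eh_proj, so the embedding branch (nextn_hidden_state, already passed through enorm) now comes before the hidden-state branch. A self-contained sketch of that combining step, with toy shapes and LayerNorm standing in for the model's RMSNorm (assumptions, not the real module wiring):

import paddle
import paddle.nn as nn

hidden_size = 8
enorm = nn.LayerNorm(hidden_size)   # stand-in for the MTP layer's enorm
hnorm = nn.LayerNorm(hidden_size)   # stand-in for hnorm
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias_attr=False)

hidden_states = paddle.randn([2, 4, hidden_size])       # transformer hidden states [B, S, H]
nextn_hidden_state = paddle.randn([2, 4, hidden_size])  # embeddings of the shifted tokens

# Order after the fix: normalized embedding branch first, then the hidden-state branch.
combined = paddle.concat([enorm(nextn_hidden_state), hnorm(hidden_states)], axis=-1)
out = eh_proj(combined)  # [2, 4, hidden_size], fed into the MTP transformer block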
@@ -1711,10 +1718,13 @@ def get_tensor_parallel_split_mappings(num_layers):
 
             return final_actions
 
-        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
+        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers + 2)
 
         return mappings
 
+    def get_tensor_parallel_mappings(self, is_split=True):
+        return type(self)._get_tensor_parallel_mappings(self.config, is_split)
+
     def _init_weights(self, layer):
         return
         if self.config.tensor_parallel_degree > 1:
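Note: the mapping table is now built for config.num_hidden_layers + 2 layer indices, which presumably covers the extra MTP layers appended after the decoder stack, and the new instance method simply forwards to the class-level _get_tensor_parallel_mappings with the instance's config. A rough, illustrative sketch of why generating actions for extra indices matters (build_split_mappings and the key pattern below are assumptions, not PaddleNLP API):

def build_split_mappings(num_layers, split_fn):
    # Emit one keyed action per layer index; requesting extra indices means
    # weights of layers beyond the decoder stack (e.g. MTP layers) also get
    # a tensor-parallel split rule instead of being silently skipped.
    actions = {}
    for layer_idx in range(num_layers):
        actions[f"layers.{layer_idx}.self_attn.q_b_proj.weight"] = split_fn
    return actions

num_hidden_layers = 4
mappings = build_split_mappings(num_hidden_layers + 2, split_fn=lambda w: w)
print(len(mappings))  # 6: four decoder layers plus two extra layer slots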
@@ -1988,7 +1998,7 @@ def forward(
         if self.config.sequence_parallel:
             # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
             bs, seq_len, hidden_size = inputs_embeds.shape
-            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
             # inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size])
             # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
             inputs_embeds = ScatterOp.apply(inputs_embeds)
@@ -2071,7 +2081,7 @@ def forward(
 
         if self.config.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
-            hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H]
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
             # hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])
 
         inputs_embeds_cur_depth = paddle.concat(
@@ -2173,7 +2183,7 @@ def add_loss(main_loss, loss):
         seq_length = masked_lm_labels.shape[1]
 
         if self.config.sequence_parallel:
-            masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B]
+            masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
            masked_lm_labels = ScatterOp.apply(masked_lm_labels)
 
         loss = compute_loss(prediction_scores, masked_lm_labels)
@@ -2188,16 +2198,15 @@ def add_loss(main_loss, loss):
             masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
 
             if self.config.sequence_parallel:
-                masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B]
+                masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
                 masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)
 
             res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)
-
+
             if self.config.sequence_parallel:
                 res_cur_depth = res_cur_depth * self.seq_para_scale
                 dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)
 
-
             mtp_loss_res.append(res_cur_depth)
         loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res))  # fmt: skip
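Note: with sequence parallelism each model-parallel rank only sees its shard of the sequence, so the per-depth loss is rescaled by seq_para_scale and all-reduced over mp_group before the depths are averaged. A minimal sketch of the depth aggregation alone, without the distributed calls (plain floats stand in for loss tensors):

def add_mtp_loss(main_loss, mtp_loss_res, num_nextn_predict_lambda):
    # Average the per-depth MTP losses and add them to the main LM loss
    # with the configurable weight, mirroring the add_loss call above.
    if not mtp_loss_res:
        return main_loss
    return main_loss + num_nextn_predict_lambda * sum(mtp_loss_res) / len(mtp_loss_res)

print(add_mtp_loss(2.0, [0.5, 0.7], num_nextn_predict_lambda=0.3))  # ~2.18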

@@ -2245,9 +2254,9 @@ def __init__(self, config: DeepseekV2Config):
     def forward(self, hidden_states, tensor_parallel_output=None):
 
         # if self.config.sequence_parallel:
-        # hidden_states = GatherOp.apply(hidden_states)
-        # hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
-        # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
+        #     hidden_states = GatherOp.apply(hidden_states)
+        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        #     hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

paddlenlp/transformers/moe_gate.py

Lines changed: 11 additions & 3 deletions
@@ -578,11 +578,19 @@ def topkgating_nodrop(self, gates: paddle.Tensor):
         # get topk mask
         mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1)
 
+        # hongyu fix start
+        gates_masked = gates * mask
+        gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
+        denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
+
+        if self.norm_topk_prob:
+            gates_masked = gates_masked / denom_s
+        gates_masked *= self.routed_scaling_factor
+        # hongyu fix end
         if hasattr(self.config, "seq_aux") and self.config.seq_aux:
             l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
         else:
             l_aux = self._cal_aux_loss(gates, mask)
-
         exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
-        topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
-        return topk_masked_gates, mask, exp_counts, l_aux, l_zloss
+        # topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
+        return gates_masked, mask, exp_counts, l_aux, l_zloss
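Note: topkgating_nodrop now returns gates_masked — the router probabilities zeroed outside the top-k, renormalized to sum to 1 over the selected experts when norm_topk_prob is set, then scaled by routed_scaling_factor — instead of re-scattering top_gate. A self-contained numerical sketch of that normalization (toy values; norm_topk_prob=True and routed_scaling_factor=1.0 assumed):

import paddle

gates = paddle.to_tensor([[0.1, 0.4, 0.2, 0.3]])  # router probabilities for 4 experts
top_k = 2
routed_scaling_factor = 1.0

# Same top-k mask construction as in topkgating_nodrop.
_, top_idx = paddle.topk(gates, k=top_k, axis=-1)
mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1)

gates_masked = gates * mask                                # [[0.0, 0.4, 0.0, 0.3]]
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)  # 0.7
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
gates_masked = gates_masked / denom_s                      # renormalize over the top-k
gates_masked *= routed_scaling_factor
print(gates_masked)  # ~[[0.0, 0.571, 0.0, 0.429]]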

paddlenlp/transformers/moe_layer.py

Lines changed: 25 additions & 4 deletions
@@ -175,7 +175,6 @@ def __init__(
             is_fleet_init = True
         except AttributeError:
             is_fleet_init = False
-
         if is_fleet_init and dist.get_world_size() > 1:
             if moe_group == "data":
                 self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
@@ -198,7 +197,6 @@ def __init__(
             self.expert_parallel_degree = 1
             self.moe_num_experts_per_device = self.moe_num_experts
             self.is_dummy_moe = True
-
         self.all_to_all_dropout = all_to_all_dropout
         self.enable_recompute = False

@@ -348,21 +346,34 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m
         self.moe_router_topk = gate.top_k
         self.moe_num_experts = moe_num_experts
         self.num_local_experts = moe_num_experts // self.ep_size
+        self.moe_rank = dist.get_rank(self.moe_group)
+        self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank
         self.token_dispatcher = MoEFlexTokenDispatcher(
             self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group
         )
-        self.experts = nn.LayerList([expert_class(**expert_kwargs) for _ in range(self.num_local_experts)])
+        self.expert_parallel_degree = 1 if self.ep_size < 0 else self.ep_size
+        self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
+            self.moe_num_experts, self.expert_parallel_degree
+        )
+        self.experts = nn.LayerList([])
+        for i in range(self.moe_num_experts):
+            if i // self.moe_num_experts_per_device == self.moe_rank:
+                self.experts.append(expert_class(**expert_kwargs))
+            else:
+                self.experts.append(None)
         self.router = gate
 
     def expert_forward(self, dispatched_input, tokens_per_expert):
         outputs = []
         tokens_per_expert = tokens_per_expert.tolist()
         # print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}")
         chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0)
-        for chunk, expert in zip(chunks, self.experts):
+        for i, chunk in enumerate(chunks):
             chunk = chunk.contiguous()
             # assert chunk.shape[0] != 0, "Cannot dispatch empty input"
             # print("expert token:", chunk.shape, flush=True)
+            # assert chunk.shape[0] != 0, "Cannot dispatch empty input"
+            expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device]
             outputs += [expert(chunk)]
 
         return paddle.concat(outputs, axis=0)
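Note: the layer now allocates a LayerList of length moe_num_experts in which only the slots owned by the current rank hold real expert modules (the remaining slots are None), and expert_forward indexes experts with the global offset moe_rank * moe_num_experts_per_device. A small pure-Python sketch of that placement rule (helper names are illustrative):

def build_expert_slots(moe_num_experts, experts_per_device, moe_rank, make_expert):
    # Expert i is owned by rank i // experts_per_device; every other slot
    # stays None so global expert indices remain valid on every rank.
    slots = []
    for i in range(moe_num_experts):
        if i // experts_per_device == moe_rank:
            slots.append(make_expert(i))
        else:
            slots.append(None)
    return slots

# 8 routed experts, 2 experts per device -> rank 1 owns global experts 2 and 3.
slots = build_expert_slots(8, 2, moe_rank=1, make_expert=lambda i: f"expert_{i}")
print([s for s in slots if s is not None])  # ['expert_2', 'expert_3']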
@@ -377,3 +388,13 @@ def forward(self, hidden_states: paddle.Tensor):
         expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
         output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
         return output, l_aux, l_zloss
+
+    def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree):
+        assert (
+            moe_num_experts >= expert_parallel_degree
+        ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}"
+        assert (
+            moe_num_experts % expert_parallel_degree == 0
+        ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0"
+        moe_num_experts_per_device = moe_num_experts // expert_parallel_degree
+        return moe_num_experts_per_device
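Note: _parse_moe_expert_parallel just validates that the expert count can be sharded evenly over the expert-parallel degree and returns the per-device count. A quick usage sketch with illustrative numbers:

# 64 routed experts sharded over an expert-parallel degree of 8.
moe_num_experts, expert_parallel_degree = 64, 8
assert moe_num_experts >= expert_parallel_degree
assert moe_num_experts % expert_parallel_degree == 0
moe_num_experts_per_device = moe_num_experts // expert_parallel_degree
print(moe_num_experts_per_device)  # 8 experts instantiated on each device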
