Commit 4d1b2a9

Dense node supports send_mtp_embed (#10967)
1 parent 6e0feba commit 4d1b2a9

File tree

paddlenlp/transformers/deepseek_v2/modeling.py
paddlenlp/transformers/deepseek_v2/modeling_pp.py

2 files changed: +25 -5 lines changed


paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -2400,7 +2400,7 @@ def self_attn_compute(self, hidden_states, **kwargs):
 
         residual = hidden_states
 
-        if not self.using_post_norm_recompute:
+        if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)):
             hidden_states = self.post_attention_layernorm(hidden_states)
 
         return hidden_states, residual
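With this guard, the post-attention layernorm is skipped only when post-norm recompute is enabled and the layer's MLP is a DeepseekV2MoE block; dense (non-MoE) layers now always normalize inside self_attn_compute. A minimal sketch of that decision, using hypothetical stand-in classes and a helper name that is not in the repository:

class DeepseekV2MoE:  # stand-in for the MoE MLP class referenced in the diff
    pass

class DenseMLP:  # hypothetical stand-in for a dense (non-MoE) MLP
    pass

def applies_post_attention_layernorm(using_post_norm_recompute, mlp):
    # Mirrors the new condition: skip the layernorm here only when post-norm
    # recompute is on AND the following MLP is a MoE block.
    return not (using_post_norm_recompute and isinstance(mlp, DeepseekV2MoE))

assert applies_post_attention_layernorm(False, DenseMLP())
assert applies_post_attention_layernorm(True, DenseMLP())            # dense layer: layernorm stays here
assert applies_post_attention_layernorm(False, DeepseekV2MoE())
assert not applies_post_attention_layernorm(True, DeepseekV2MoE())   # MoE + recompute: layernorm skipped here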

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 24 additions & 4 deletions
@@ -1167,6 +1167,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw")
 
         paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine")
+        # Note: the input combine_bw_event_to_wait is unreliable, we need to record a new event here.
+        combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
         output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
         output_grad = self.backward_node.combine_backward(
             output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
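The two added lines re-record the dependency event at the point of use: rather than trusting the combine_bw_event_to_wait handed in by the caller, a fresh event is captured from the calculation stream so the asynchronous combine waits on the calc stream's current state instead of a possibly stale handle. A self-contained sketch of that ordering pattern, with hypothetical stand-ins for the stream and event (the real code uses DeepEP's get_event_from_calc_stream and CUDA streams):

class FakeEvent:
    # hypothetical stand-in for a CUDA event; remembers how much work it covers
    def __init__(self, covered):
        self.covered = covered

class FakeCalcStream:
    # hypothetical stand-in for the calculation (compute) stream
    def __init__(self):
        self.enqueued = []
    def launch(self, kernel_name):
        self.enqueued.append(kernel_name)
    def record_event(self):
        # like deep_ep.get_event_from_calc_stream(...): the event covers everything
        # enqueued on the calc stream up to this point
        return FakeEvent(list(self.enqueued))

def combine_backward(grad, previous_event, calc_stream):
    # the async combine must not start before all calc-stream work enqueued so far
    assert previous_event.covered == calc_stream.enqueued, "stale event: combine could overlap unfinished compute"
    return grad

calc_stream = FakeCalcStream()
caller_event = calc_stream.record_event()     # event recorded earlier by the caller
calc_stream.launch("dense_fw")                # work enqueued after that event was recorded
fresh_event = calc_stream.record_event()      # re-recorded at the point of use, as in the added lines
output_grad = combine_backward("output_grad", previous_event=fresh_event, calc_stream=calc_stream)
# passing caller_event instead would trip the assert: it does not cover "dense_fw"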
@@ -1625,13 +1627,31 @@ def attn_compute_dense(self, args):
         assert attention_mask is None
         assert attn_mask_startend_row_indices is None
         assert position_ids is None
-        hidden_states, _ = self.self_attn_compute(hidden_states)
-        return hidden_states
 
-    def mlp_compute_dense(self, hidden_states):
-        residual = hidden_states
+        if self.config.send_mtp_embed:
+            batch_size, _, hidden_size = hidden_states.shape
+            batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
+            hidden_states = hidden_states[..., :batch_size_mtp]
+
+        hidden_states, residual = self.self_attn_compute(hidden_states)
+
+        ret = (hidden_states, residual)
+        ret = (inputs_embeds_mtp, *ret) if self.config.send_mtp_embed else ret
+        return ret
+
+    def mlp_compute_dense(self, inputs):
+        if self.config.send_mtp_embed:
+            (inputs_embeds_mtp, hidden_states, residual) = inputs
+        else:
+            (hidden_states, residual) = inputs
+
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
+
+        if self.config.send_mtp_embed:
+            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+
         return hidden_states
 
     def build_schedule_node(self):
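For context on send_mtp_embed: the new dense-node path assumes the multi-token-prediction (MTP) embeddings travel packed into the last axis of hidden_states, as num_nextn_predict_layers extra chunks after the main hidden states. attn_compute_dense slices them off before attention, and mlp_compute_dense re-appends them after the dense MLP so downstream stages still receive them. A minimal, self-contained sketch of that split/repack convention (shapes and variable names here are illustrative, not from the repository):

import paddle

num_nextn_predict_layers = 1    # assumed config value
hidden_size = 4                 # width of the main hidden states (illustrative)

# packed layout: [batch, seq, hidden_size * (num_nextn_predict_layers + 1)]
packed = paddle.randn([2, 3, hidden_size * (num_nextn_predict_layers + 1)])

# split, as in attn_compute_dense: the first chunk is the main hidden states,
# the remainder are the MTP embeddings
main_width = packed.shape[-1] // (num_nextn_predict_layers + 1)
hidden_states = packed[..., :main_width]
inputs_embeds_mtp = packed[..., main_width:]

# ... attention and the dense MLP run on `hidden_states` only ...

# repack, as in mlp_compute_dense, so later stages still see the MTP embeddings
repacked = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
assert repacked.shape == packed.shape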
