@@ -1167,6 +1167,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
1167
1167
paddle .base .core .nvprof_nvtx_push ("dense_fw_moe_bw" )
1168
1168
1169
1169
paddle .base .core .nvprof_nvtx_push ("dense_attn_moe_combine" )
1170
+ # Note: the input combine_bw_event_to_wait is unreliable, we need to record a new event here.
1171
+ combine_bw_event_to_wait = deep_ep .get_event_from_calc_stream (self .backward_node .moe_group .id )
1170
1172
output_grad = self .backward_node .post_process_backward (output_grad , combine_bw_event_to_wait )
1171
1173
output_grad = self .backward_node .combine_backward (
1172
1174
output_grad , previous_event = combine_bw_event_to_wait , async_finish = True , allocate_on_comm_stream = True
@@ -1625,13 +1627,31 @@ def attn_compute_dense(self, args):
1625
1627
assert attention_mask is None
1626
1628
assert attn_mask_startend_row_indices is None
1627
1629
assert position_ids is None
1628
- hidden_states , _ = self .self_attn_compute (hidden_states )
1629
- return hidden_states
1630
1630
1631
- def mlp_compute_dense (self , hidden_states ):
1632
- residual = hidden_states
1631
+ if self .config .send_mtp_embed :
1632
+ batch_size , _ , hidden_size = hidden_states .shape
1633
+ batch_size_mtp = hidden_size // (self .config .num_nextn_predict_layers + 1 )
1634
+ inputs_embeds_mtp = hidden_states [..., batch_size_mtp :]
1635
+ hidden_states = hidden_states [..., :batch_size_mtp ]
1636
+
1637
+ hidden_states , residual = self .self_attn_compute (hidden_states )
1638
+
1639
+ ret = (hidden_states , residual )
1640
+ ret = (inputs_embeds_mtp , * ret ) if self .config .send_mtp_embed else ret
1641
+ return ret
1642
+
1643
def mlp_compute_dense(self, inputs):
    """Run the dense-branch MLP over the packed forward inputs.

    `inputs` is the tuple produced by the preceding attention step:
    `(hidden_states, residual)` normally, or
    `(inputs_embeds_mtp, hidden_states, residual)` when
    `config.send_mtp_embed` is set (the multi-token-prediction
    embeddings ride along as the leading element).

    Returns the residual-added MLP output; when `send_mtp_embed` is
    set, the mtp embeddings are concatenated back onto the last axis
    so downstream stages receive the same packed layout.
    """
    if self.config.send_mtp_embed:
        inputs_embeds_mtp, hidden_states, residual = inputs
    else:
        hidden_states, residual = inputs

    # MLP followed by the residual connection.
    hidden_states = residual + self.mlp(hidden_states)

    # Re-pack the mtp embeddings so the output mirrors the input layout.
    if self.config.send_mtp_embed:
        hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)

    return hidden_states
1636
1656
1637
1657
def build_schedule_node (self ):
0 commit comments