
Commit 21d7d20

fix (#10970)
1 parent 4d1b2a9 commit 21d7d20

File tree

2 files changed: +15 -5 lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 1 addition & 0 deletions

@@ -1023,6 +1023,7 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
                     using_post_norm_recompute=self.using_post_norm_recompute,
                     norm_weight=norm_weight,
                     norm_eps=norm_eps,
+                    recompute_fwd_gate_up=True,
                 )
             else:
                 self.shared_experts = DeepseekV2MLPClass(
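
Note: the effect of passing `recompute_fwd_gate_up=True` here is defined by the fp8_utils.py changes below; condensed to its key line (repeated only for context, not new code in modeling.py):

```python
# From FP8MlpFunction.forward below: with the flag set, the gate/up projection
# output o1 is dropped before ctx.save_for_backward(...) and recomputed in backward.
o1 = None if recompute_fwd_gate_up else o1
```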

paddlenlp/transformers/fp8_utils.py

Lines changed: 14 additions & 5 deletions

@@ -685,7 +685,7 @@ def backward(ctx, do3):
 
 class FP8MlpFunction(paddle.autograd.PyLayer):
     @staticmethod
-    def forward(ctx, x, w1, w2):
+    def forward(ctx, x, w1, w2, recompute_fwd_gate_up):
         # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
         x_orig_shape = x.shape
         x = x.reshape([-1, x_orig_shape[-1]])
@@ -697,6 +697,7 @@ def forward(ctx, x, w1, w2):
         o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
         # ===== save for backward =====
+        o1 = None if recompute_fwd_gate_up else o1
         ctx.save_for_backward(
             o1,
             x_fp8,
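
For intuition, here is a minimal, self-contained sketch (not part of the commit: plain float32, a ReLU stand-in instead of the real FP8 kernels and activation, illustrative class name) of the trade-off the added line enables. When the flag is set, the gate/up output is not saved between forward and backward; the backward pass rebuilds it from the saved input, trading extra compute for lower activation memory. The backward side of the real code follows in the next hunk.

```python
import paddle
import paddle.nn.functional as F


class RecomputeGateUpMLP(paddle.autograd.PyLayer):
    """Toy MLP illustrating the recompute_fwd_gate_up idea (no FP8 quantization)."""

    @staticmethod
    def forward(ctx, x, w1, w2, recompute_fwd_gate_up=False):
        o1 = paddle.matmul(x, w1)   # "gate/up" projection
        o2 = F.relu(o1)             # stand-in for the real activation
        o3 = paddle.matmul(o2, w2)  # "down" projection
        # When the flag is set, drop o1 so it is not kept alive until backward.
        ctx.save_for_backward(x, w1, w2, None if recompute_fwd_gate_up else o1)
        return o3

    @staticmethod
    def backward(ctx, do3):
        x, w1, w2, o1 = ctx.saved_tensor()
        if o1 is None:                   # recompute path: rebuild the projection output
            o1 = paddle.matmul(x, w1)
        o2 = F.relu(o1)
        dw2 = paddle.matmul(o2, do3, transpose_x=True)
        do2 = paddle.matmul(do3, w2, transpose_y=True)
        do1 = do2 * (o1 > 0).astype(do2.dtype)   # relu'(o1) * upstream grad
        dw1 = paddle.matmul(x, do1, transpose_x=True)
        dx = paddle.matmul(do1, w1, transpose_y=True)
        return dx, dw1, dw2


# Usage: results are identical either way; the flag only changes what is cached.
x = paddle.randn([4, 8])
w1 = paddle.randn([8, 16])
w2 = paddle.randn([16, 8])
for t in (x, w1, w2):
    t.stop_gradient = False

out = RecomputeGateUpMLP.apply(x, w1, w2, True)
out.sum().backward()
```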
@@ -729,9 +730,14 @@ def backward(ctx, do3):
         )
 
         # ===== call func common_fp8_mlp_bwd =====
-        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-            do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
-        )
+        if o1 is None:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True
+            )
+        else:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
+            )
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
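
The two branches pass the same positional arguments and differ only in which cached tensors reach `FP8LinearFunctionBase.common_fp8_mlp_bwd`: when `o1` was dropped in forward, the quantized input (`x_fp8`, `x_scale`) is handed over instead, presumably so the gate/up output can be rebuilt inside the helper. A compact equivalent of the branch (a sketch, not part of the commit):

```python
# Behaviorally equivalent to the if/else above: hand over either the cached o1
# or the quantized input from which it can be recomputed.
dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
    do3, x_t_fp8, x_t_scale, w1, w2,
    o1=o1,
    x_fp8=x_fp8 if o1 is None else None,
    x_scale=x_scale if o1 is None else None,
    apply_backward_hook=True,
)
```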
@@ -749,6 +755,7 @@ def __init__(
         using_post_norm_recompute=False,
         norm_weight=None,
         norm_eps=None,
+        recompute_fwd_gate_up=False,
     ):
         super().__init__()
         self.config = config
@@ -761,6 +768,8 @@ def __init__(
         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
         self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
 
+        self.recompute_fwd_gate_up = recompute_fwd_gate_up
+
         self.w1 = self.create_parameter(
             shape=[self.hidden_size, self.intermediate_size * 2],
             dtype="bfloat16",
@@ -780,7 +789,7 @@ def forward(self, x):
         if self.using_post_norm_recompute:
             return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
         else:
-            return FP8MlpFunction.apply(x, self.w1, self.w2)
+            return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up)
 
 
 def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out):
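
Since the new constructor argument defaults to False and `forward` simply passes the stored flag through, existing callers keep the old behavior; only the shared-experts path in modeling.py opts in. A usage sketch of the changed PyLayer entry point (names taken from the hunks above):

```python
# Old behavior: the gate/up output o1 stays cached between forward and backward.
y = FP8MlpFunction.apply(x, w1, w2, False)

# New opt-in behavior: o1 is dropped after forward and recomputed during backward,
# trading some extra compute for lower activation memory.
y = FP8MlpFunction.apply(x, w1, w2, True)
```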
