16 changes: 15 additions & 1 deletion paddlenlp/transformers/deepseek_v2/modeling.py
@@ -106,6 +106,7 @@ def swiglu(x, y=None):
             x, y = paddle.chunk(x, chunks=2, axis=-1)
         return F.silu(x) * y
 
+
 try:
     from paddle.incubate.nn.functional import fused_partial_rope
 except ImportError:
@@ -752,6 +753,7 @@ def forward(self, x):
 
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+
     _current_norm_output = None
     _current_invar = None
 
@@ -799,6 +801,7 @@ def backward(ctx, d_gate_logits, d_norm_output):
 
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
+
 class TemporaryVarContext:
     def __init__(self, norm_output, invar):
         self.norm_output = norm_output
@@ -810,6 +813,7 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         FusedNormGateFunc.clear_temporary_vars()
 
+
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
     matrix = paddle.zeros((n, m), dtype=paddle.int32)
@@ -1014,6 +1018,7 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
                 using_post_norm_recompute=self.using_post_norm_recompute,
                 norm_weight=norm_weight,
                 norm_eps=norm_eps,
+                recompute_fwd_gate_up=True,
             )
         else:
             self.shared_experts = DeepseekV2MLPClass(
@@ -1171,7 +1176,16 @@ def qkv_pre_process(
 ):
     if (fused_partial_rope is None) or (position_ids is not None):
         return qkv_pre_process_no_fuse(
-            q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
+            q,
+            kv,
+            k_pe,
+            rotary_emb,
+            num_heads,
+            q_head_dim,
+            qk_nope_head_dim,
+            v_head_dim,
+            qk_rope_head_dim,
+            position_ids,
         )
 
     bsz, q_len, _ = q.shape
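For readers skimming the modeling.py hunks above: the functional change is that the shared-experts MLP now receives recompute_fwd_gate_up=True when post-norm recompute is enabled; the remaining edits are blank-line and argument-list formatting. The FusedNormGateFunc / TemporaryVarContext pair works by stashing an already-computed norm output in class-level variables for the duration of a with block. Since the __enter__ body is collapsed in the diff, the sketch below is an assumption-labelled reconstruction of that pattern with a hypothetical set_temporary_vars helper (only clear_temporary_vars is visible above); it is illustrative, not the PR's code.

class _NormGateFunc:
    # Class-level temporaries, mirroring _current_norm_output / _current_invar above.
    _current_norm_output = None
    _current_invar = None

    @classmethod
    def set_temporary_vars(cls, norm_output, invar):  # assumed counterpart of clear_temporary_vars
        cls._current_norm_output = norm_output
        cls._current_invar = invar

    @classmethod
    def clear_temporary_vars(cls):
        cls._current_norm_output = None
        cls._current_invar = None


class _TemporaryVarContext:
    def __init__(self, norm_output, invar):
        self.norm_output = norm_output
        self.invar = invar

    def __enter__(self):
        _NormGateFunc.set_temporary_vars(self.norm_output, self.invar)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Clear on exit, including on exceptions, so cached tensors are not kept alive.
        _NormGateFunc.clear_temporary_vars()


# Usage: code running inside the block reads the cached values instead of recomputing them.
with _TemporaryVarContext("norm_out", "invar"):
    assert _NormGateFunc._current_norm_output == "norm_out"
assert _NormGateFunc._current_norm_output is None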
19 changes: 14 additions & 5 deletions paddlenlp/transformers/fp8_utils.py
@@ -660,7 +660,7 @@ def backward(ctx, do3):
 
 class FP8MlpFunction(paddle.autograd.PyLayer):
     @staticmethod
-    def forward(ctx, x, w1, w2):
+    def forward(ctx, x, w1, w2, recompute_fwd_gate_up):
         # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
         x_orig_shape = x.shape
         x = x.reshape([-1, x_orig_shape[-1]])
@@ -672,6 +672,7 @@ def forward(ctx, x, w1, w2):
         o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
         # ===== save for backward =====
+        o1 = None if recompute_fwd_gate_up else o1
         ctx.save_for_backward(
             o1,
             x_fp8,
@@ -704,9 +705,14 @@ def backward(ctx, do3):
         )
 
         # ===== call func common_fp8_mlp_bwd =====
-        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-            do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
-        )
+        if o1 is None:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True
+            )
+        else:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
+            )
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
@@ -724,6 +730,7 @@ def __init__(
         using_post_norm_recompute=False,
         norm_weight=None,
         norm_eps=None,
+        recompute_fwd_gate_up=False,
     ):
         super().__init__()
         self.config = config
@@ -736,6 +743,8 @@
         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
         self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
 
+        self.recompute_fwd_gate_up = recompute_fwd_gate_up
+
         self.w1 = self.create_parameter(
             shape=[self.hidden_size, self.intermediate_size * 2],
             dtype="bfloat16",
@@ -755,7 +764,7 @@ def forward(self, x):
         if self.using_post_norm_recompute:
             return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
         else:
-            return FP8MlpFunction.apply(x, self.w1, self.w2)
+            return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up)
 
 
 def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out):
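Taken together, the fp8_utils.py hunks add an optional activation-recompute mode to the FP8 MLP: with recompute_fwd_gate_up set, forward drops the gate-up output o1 from the tensors saved for backward, and backward rebuilds it inside FP8LinearFunctionBase.common_fp8_mlp_bwd from the quantized input (x_fp8 / x_scale), trading that extra GEMM for lower activation memory. The sketch below shows the same save-versus-recompute decision on a plain two-matmul MLP in framework-free Python; it deliberately omits SwiGLU, FP8 quantization, and the backward weight-gradient hooks, so treat it as a conceptual illustration under those assumptions rather than the PR's implementation.

import numpy as np


def mlp_forward(x, w1, w2, recompute_fwd_gate_up=False):
    # y = (x @ w1) @ w2; o1 plays the role of the gate-up output.
    o1 = x @ w1
    y = o1 @ w2
    # Mirror of `o1 = None if recompute_fwd_gate_up else o1` in the diff:
    # when recomputing, only the input is kept for backward.
    saved = (x, None if recompute_fwd_gate_up else o1)
    return y, saved


def mlp_backward(dy, saved, w1, w2):
    x, o1 = saved
    if o1 is None:
        # Recompute path: rebuild the intermediate from the saved input,
        # paying one extra matmul instead of storing o1 between passes.
        o1 = x @ w1
    dw2 = o1.T @ dy
    do1 = dy @ w2.T
    dw1 = x.T @ do1
    dx = do1 @ w1.T
    return dx, dw1, dw2


# Both modes yield identical gradients; only peak activation memory differs.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
w1 = rng.standard_normal((8, 16))
w2 = rng.standard_normal((16, 8))
dy = np.ones((4, 8))
_, saved_keep = mlp_forward(x, w1, w2, recompute_fwd_gate_up=False)
_, saved_drop = mlp_forward(x, w1, w2, recompute_fwd_gate_up=True)
for g_keep, g_drop in zip(mlp_backward(dy, saved_keep, w1, w2), mlp_backward(dy, saved_drop, w1, w2)):
    assert np.allclose(g_keep, g_drop)

In the actual change, the recomputation path is selected in FP8MlpFunction.backward by checking whether o1 was saved, and common_fp8_mlp_bwd receives x_fp8 / x_scale in place of o1 when the flag is on.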