
Commit 6e0feba

add quant cache (#10969)
1 parent: 0e9051a

3 files changed: +61 -2 lines changed


paddlenlp/trainer/trainer_callback.py

Lines changed: 19 additions & 0 deletions
@@ -27,6 +27,7 @@
 import numpy as np
 from tqdm.auto import tqdm
 
+from paddlenlp.transformers.moe_utils import offload, reload
 from paddlenlp.utils.log import logger
 
 from .trainer_utils import IntervalStrategy, has_length
@@ -646,5 +647,23 @@ def on_step_begin(self, args, state, control, **kwargs):
         if not g_shard_bypass_dygraph_optimizer or skip_count == 0:
             model.fp8_quant_weight(True)
             optimizer.clear_param_storage("moe_expert")
+            optimizer.clear_param_storage("rms_linear")
+            optimizer.clear_param_storage("memory_attn")
+            optimizer.clear_param_storage("attn_out_project")
+            optimizer.clear_param_storage("shared_expert")
+
+            self.moe_weights_name = []
+            for param in optimizer._inner_opt._parameter_list:
+                color = getattr(param, "color", -1)
+                if isinstance(color, dict) and color["color"] == "moe_expert":
+                    self.moe_weights_name.append(param.name)
+
+            for name in self.moe_weights_name:
+                offload(optimizer._master_weights[name])
 
         skip_count += 1
+
+    def on_optimizer_begin(self, args, state, control, **kwargs):
+        optimizer = kwargs["optimizer"]
+        for name in self.moe_weights_name:
+            reload(optimizer._master_weights[name])
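
The new callback logic keys off the "color" tag attached to parameters elsewhere in this commit: after model.fp8_quant_weight(True) caches FP8 copies, the optimizer storage for each colored group is cleared, the master weights of "moe_expert"-colored parameters are offloaded at step begin, and on_optimizer_begin reloads them just before the update. Below is a minimal, framework-free sketch of that color-filter plus offload/reload lifecycle; FakeParam, FakeOptimizer, and the print-based offload/reload are hypothetical stand-ins, while the real callback walks optimizer._inner_opt._parameter_list and uses paddlenlp.transformers.moe_utils.

# Minimal sketch (hypothetical stand-ins, not PaddleNLP APIs).
class FakeParam:
    def __init__(self, name, color=None):
        self.name = name
        if color is not None:
            self.color = color  # same dict form as in the diff: {"color": ...}


class FakeOptimizer:
    def __init__(self, params):
        self._parameter_list = params
        self._master_weights = {p.name: "master weight of " + p.name for p in params}


def offload(buf):
    print("offloaded:", buf)  # stand-in: the real offload moves the tensor off device


def reload(buf):
    print("reloaded:", buf)  # stand-in: the real reload copies it back to device


def select_by_color(optimizer, color_name):
    # Same filter as on_step_begin: keep params whose color dict matches.
    names = []
    for param in optimizer._parameter_list:
        color = getattr(param, "color", -1)
        if isinstance(color, dict) and color["color"] == color_name:
            names.append(param.name)
    return names


opt = FakeOptimizer([FakeParam("expert_0.w1", {"color": "moe_expert"}), FakeParam("dense.w")])
moe_names = select_by_color(opt, "moe_expert")  # -> ["expert_0.w1"]
for n in moe_names:                             # on_step_begin: free device memory
    offload(opt._master_weights[n])
for n in moe_names:                             # on_optimizer_begin: bring it back
    reload(opt._master_weights[n])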

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 25 additions & 2 deletions
@@ -88,6 +88,7 @@
     FP8LinearFunctionBase,
     FP8Mlp,
     cache_fp8_weight,
+    set_parameter_color,
 )
 from .fp8_linear import Linear
 
@@ -106,6 +107,7 @@ def swiglu(x, y=None):
         x, y = paddle.chunk(x, chunks=2, axis=-1)
     return F.silu(x) * y
 
+
 try:
     from paddle.incubate.nn.functional import fused_partial_rope
 except ImportError:
@@ -752,6 +754,7 @@ def forward(self, x):
 
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+
     _current_norm_output = None
     _current_invar = None
 
@@ -799,6 +802,7 @@ def backward(ctx, d_gate_logits, d_norm_output):
 
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
+
 class TemporaryVarContext:
     def __init__(self, norm_output, invar):
         self.norm_output = norm_output
@@ -810,6 +814,7 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         FusedNormGateFunc.clear_temporary_vars()
 
+
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
     matrix = paddle.zeros((n, m), dtype=paddle.int32)
@@ -999,7 +1004,11 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
 
         if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant:
             moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
-            for p in self.experts.parameters():
+            expert_w1_list = [expert.w1 for expert in self.experts if expert is not None]
+            expert_w2_list = [expert.w2 for expert in self.experts if expert is not None]
+            for p in expert_w1_list:
+                setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+            for p in expert_w2_list:
                 setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
 
         self.alpha = config.aux_loss_alpha
@@ -1019,6 +1028,7 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             self.shared_experts = DeepseekV2MLPClass(
                 config=config, intermediate_size=intermediate_size, is_moe=False
             )
+            set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert")
 
     def fp8_quant_weight(self, batch_mode=False):
         """Quantize weights in FP8 format.
@@ -1171,7 +1181,16 @@ def qkv_pre_process(
 ):
     if (fused_partial_rope is None) or (position_ids is not None):
         return qkv_pre_process_no_fuse(
-            q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
+            q,
+            kv,
+            k_pe,
+            rotary_emb,
+            num_heads,
+            q_head_dim,
+            qk_nope_head_dim,
+            v_head_dim,
+            qk_rope_head_dim,
+            position_ids,
         )
 
     bsz, q_len, _ = q.shape
@@ -1712,6 +1731,7 @@ def __init__(
             kv_lora_rank,
             softmax_scale,
         )
+        set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")
 
     def fp8_quant_weight(self):
         cache_fp8_weight(self.q_up_weight)
@@ -1839,6 +1859,7 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
             is_bias=False,
         )
         self.eps = eps
+        set_parameter_color([self.q_down_weight], "rms_linear")
 
     def fp8_quant_weight(self):
         cache_fp8_weight(self.q_down_weight)
@@ -2237,6 +2258,8 @@ def fp8_quant_weight(self, batch_mode=False):
             # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
             self.mlp.fp8_quant_weight(batch_mode)
             self.self_attn.fp8_quant_weight()
+        elif isinstance(self.mlp, FP8Mlp):
+            self.self_attn.fp8_quant_weight()
 
     def forward(
         self,
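
The modeling changes attach the color tags at construction time: routed-expert w1/w2 weights get "moe_expert" together with the expert gradient communication group, shared-expert weights get "shared_expert", the MLA up-projection weights get "memory_attn", the q down-projection gets "rms_linear", and (in fp8_utils.py below) the attention output projection gets "attn_out_project". A hedged sketch of how such tags can later be grouped by name; group_params_by_color and the toy W class are hypothetical helpers, not PaddleNLP APIs.

# Hypothetical helper mirroring how the trainer callback treats colors as
# group names ("moe_expert", "shared_expert", "memory_attn", "rms_linear",
# "attn_out_project"); not part of PaddleNLP.
from collections import defaultdict


class W:  # toy stand-in for a colored weight
    def __init__(self, name, color=None):
        self.name = name
        if color is not None:
            self.color = {"color": color}


def group_params_by_color(parameters):
    groups = defaultdict(list)
    for p in parameters:
        color = getattr(p, "color", None)
        if isinstance(color, dict):
            groups[color["color"]].append(p.name)
    return groups


weights = [W("experts.0.w1", "moe_expert"), W("shared_experts.w1", "shared_expert"), W("embed")]
print(dict(group_params_by_color(weights)))
# {'moe_expert': ['experts.0.w1'], 'shared_expert': ['shared_experts.w1']}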

paddlenlp/transformers/fp8_utils.py

Lines changed: 17 additions & 0 deletions
@@ -50,6 +50,22 @@ def swiglu(x, y=None):
 ]
 
 
+def set_parameter_color(
+    parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True
+):
+    if offline_quant_expert_weight and clear_origin_weight_when_offline_quant:
+        if group is None:
+            for p in parameters:
+                if hasattr(p, "color") and p.color is not None:
+                    continue
+                setattr(p, "color", {"color": color})
+        else:
+            for p in parameters:
+                if hasattr(p, "color") and p.color is not None:
+                    continue
+                setattr(p, "color", {"color": color, "group": group})
+
+
 def extract_first_if_tuple(x):
     return x[0] if isinstance(x, tuple) else x
 
@@ -601,6 +617,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
             dtype="bfloat16",
             is_bias=False,
         )
+        set_parameter_color([self.weight], "attn_out_project")
 
     def fp8_quant_weight(self):
         cache_fp8_weight(self.weight)
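
set_parameter_color only tags parameters that do not already carry a color, and records a communication group only when one is passed (as the expert weights do in modeling.py). A short usage sketch, assuming PaddleNLP at this commit is importable; plain objects stand in for Parameters since the function only relies on getattr/setattr.

# Usage sketch for set_parameter_color; requires PaddleNLP at this commit.
from types import SimpleNamespace

from paddlenlp.transformers.fp8_utils import set_parameter_color

w1 = SimpleNamespace(name="w1")                                 # not yet colored
w2 = SimpleNamespace(name="w2", color={"color": "moe_expert"})  # already colored

set_parameter_color([w1, w2], "shared_expert")

print(w1.color)  # {'color': 'shared_expert'}  -- newly tagged
print(w2.color)  # {'color': 'moe_expert'}     -- existing color is kept, not overwritten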
