     FP8LinearFunctionBase,
     FP8Mlp,
     cache_fp8_weight,
+    set_parameter_color,
 )
 from .fp8_linear import Linear

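`set_parameter_color` is imported here but not defined in this diff. Judging from how the MoE expert weights are tagged by hand further down (`setattr(p, "color", {"color": ..., "group": ...})`), a minimal sketch of such a helper could look like the following; the signature and the optional `group` argument are assumptions, not the actual implementation.

# Hypothetical sketch of the imported helper, inferred from the setattr pattern
# applied to the expert weights later in this patch.
def set_parameter_color(parameters, color, group=None):
    for p in parameters:
        # Tag the parameter so downstream logic (e.g. gradient sync or optimizer
        # grouping) can pick out weights by their "color" label.
        setattr(p, "color", {"color": color, "group": group})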
@@ -106,6 +107,7 @@ def swiglu(x, y=None):
     x, y = paddle.chunk(x, chunks=2, axis=-1)
     return F.silu(x) * y

+
 try:
     from paddle.incubate.nn.functional import fused_partial_rope
 except ImportError:
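For reference, the `swiglu` context lines above split a packed projection in half along the last axis and gate one half with the other. A quick hedged usage sketch (the `if y is None` guard and the tensor shapes are assumptions):

import paddle
import paddle.nn.functional as F

def swiglu(x, y=None):
    # When y is not given, x is assumed to pack both halves along the last axis
    # (the chunk/silu lines are the context lines shown in the hunk above).
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y

x = paddle.randn([2, 16, 128])   # hypothetical packed activations
print(swiglu(x).shape)           # [2, 16, 64] -- last dimension is halved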
@@ -752,6 +754,7 @@ def forward(self, x):

 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+
     _current_norm_output = None
     _current_invar = None

@@ -799,6 +802,7 @@ def backward(ctx, d_gate_logits, d_norm_output):

         return dx, d_rms_norm_weight, d_moe_gate_weight

+
 class TemporaryVarContext:
     def __init__(self, norm_output, invar):
         self.norm_output = norm_output
@@ -810,6 +814,7 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         FusedNormGateFunc.clear_temporary_vars()

+
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
     matrix = paddle.zeros((n, m), dtype=paddle.int32)
@@ -999,7 +1004,11 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):

         if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant:
             moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
-            for p in self.experts.parameters():
+            expert_w1_list = [expert.w1 for expert in self.experts if expert is not None]
+            expert_w2_list = [expert.w2 for expert in self.experts if expert is not None]
+            for p in expert_w1_list:
+                setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+            for p in expert_w2_list:
                 setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})

         self.alpha = config.aux_loss_alpha
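Tagging each expert's w1/w2 explicitly, rather than everything in self.experts.parameters(), skips entries that are None (non-local experts under expert parallelism, per the `if expert is not None` guard). A hypothetical downstream consumer of the tag might filter parameters like this (the function name is made up for illustration):

# Hypothetical use of the color tag: collect local MoE expert weights so their
# gradients can be reduced over the expert gradient communication group.
def collect_expert_params(model):
    return [
        p for p in model.parameters()
        if getattr(p, "color", {}).get("color") == "moe_expert"
    ]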
@@ -1019,6 +1028,7 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             self.shared_experts = DeepseekV2MLPClass(
                 config=config, intermediate_size=intermediate_size, is_moe=False
             )
+            set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert")

     def fp8_quant_weight(self, batch_mode=False):
         """Quantize weights in FP8 format.
@@ -1171,7 +1181,16 @@ def qkv_pre_process(
 ):
     if (fused_partial_rope is None) or (position_ids is not None):
         return qkv_pre_process_no_fuse(
-            q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids
+            q,
+            kv,
+            k_pe,
+            rotary_emb,
+            num_heads,
+            q_head_dim,
+            qk_nope_head_dim,
+            v_head_dim,
+            qk_rope_head_dim,
+            position_ids,
         )

     bsz, q_len, _ = q.shape
@@ -1712,6 +1731,7 @@ def __init__(
             kv_lora_rank,
             softmax_scale,
         )
+        set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")

     def fp8_quant_weight(self):
         cache_fp8_weight(self.q_up_weight)
@@ -1839,6 +1859,7 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
             is_bias=False,
         )
         self.eps = eps
+        set_parameter_color([self.q_down_weight], "rms_linear")

     def fp8_quant_weight(self):
         cache_fp8_weight(self.q_down_weight)
@@ -2237,6 +2258,8 @@ def fp8_quant_weight(self, batch_mode=False):
             # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
             self.mlp.fp8_quant_weight(batch_mode)
             self.self_attn.fp8_quant_weight()
+        elif isinstance(self.mlp, FP8Mlp):
+            self.self_attn.fp8_quant_weight()

     def forward(
         self,
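For context, the branch this hunk extends plausibly reads as the sketch below when taken together with this file's definitions; only the elif arm comes from the diff, while the enclosing if condition and the DeepseekV2MoE name are assumptions inferred from the context lines.

# Hedged reconstruction of the decoder layer's fp8_quant_weight dispatch.
def fp8_quant_weight(self, batch_mode=False):
    if isinstance(self.mlp, DeepseekV2MoE):    # assumed MoE branch (context lines)
        # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
        self.mlp.fp8_quant_weight(batch_mode)  # quantize expert weights
        self.self_attn.fp8_quant_weight()      # quantize attention weights
    elif isinstance(self.mlp, FP8Mlp):         # new: dense layers with an FP8 MLP
        self.self_attn.fp8_quant_weight()      # still cache FP8 attention weights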