import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.communication.reduce import ReduceOp
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -799,7 +799,7 @@ def __init__(self, config: DeepseekV2Config):

        for p in self.experts.parameters():
            setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
-
+            setattr(p, "is_moe_param", True)
        self.alpha = config.aux_loss_alpha
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -851,6 +851,7 @@ def __init__(self, config: DeepseekV2Config):

        for p in self.experts.parameters():
            setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+            setattr(p, "is_moe_param", True)

        self.alpha = config.aux_loss_alpha
        if config.n_shared_experts is not None:
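Note: the new `is_moe_param` tag is only set in these two constructors; its consumer is not shown in this diff. A minimal, hypothetical sketch of how such a tag could be used downstream, for example to split expert weights from dense weights when building optimizer parameter groups (the helper name and grouping policy are assumptions, not part of this change):

# Hypothetical helper, not part of this diff: partition parameters by the
# "is_moe_param" attribute set in the MoE __init__ above.
def split_moe_params(model):
    expert_params, dense_params = [], []
    for p in model.parameters():
        if getattr(p, "is_moe_param", False):
            expert_params.append(p)
        else:
            dense_params.append(p)
    return expert_params, dense_params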
@@ -895,7 +896,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
        self.num_heads = config.num_attention_heads
        self.num_local_heads = self.num_heads
        if config.tensor_parallel_degree > 1:
-            assert self.num_heads % config.tensor_parallel_degree == 0, f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
+            assert (
+                self.num_heads % config.tensor_parallel_degree == 0
+            ), f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})."
            self.num_local_heads = self.num_heads // config.tensor_parallel_degree

        self.max_position_embeddings = config.max_position_embeddings
@@ -1067,7 +1070,12 @@ def forward(

        if self.sequence_parallel:
            target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim]
-            target_key_value_shape = [bsz, self.seq_length, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim]
+            target_key_value_shape = [
+                bsz,
+                self.seq_length,
+                self.num_local_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+            ]
        else:
            target_query_shape = [0, 0, self.num_heads, self.q_head_dim]
            target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim]
@@ -1153,7 +1161,6 @@ def forward(
        if attn_output.shape != ori_shape:
            attn_output = attn_output.reshape(ori_shape)

-
        if not output_attentions:
            attn_weights = None

@@ -1511,7 +1518,7 @@ def forward(
        hidden_states = self.hnorm(hidden_states)
        nextn_hidden_state = self.enorm(nextn_hidden_state)

-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+        hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1))

        layer_outputs = super(DeepseekV2MTPLayer, self).forward(
            hidden_states,
@@ -1711,10 +1718,13 @@ def get_tensor_parallel_split_mappings(num_layers):

            return final_actions

-        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
+        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers + 2)

        return mappings

+    def get_tensor_parallel_mappings(self, is_split=True):
+        return type(self)._get_tensor_parallel_mappings(self.config, is_split)
+
    def _init_weights(self, layer):
        return
        if self.config.tensor_parallel_degree > 1:
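The added `get_tensor_parallel_mappings` is a thin instance-level wrapper over the existing `_get_tensor_parallel_mappings` classmethod. A hedged usage sketch; `model` stands for an instance of the patched pretrained model and is an assumption, not part of this diff:

# Illustrative only: fetch the tensor-parallel split actions from a model instance.
mappings = model.get_tensor_parallel_mappings(is_split=True)
for weight_name, action in mappings.items():
    print(weight_name, action)  # each entry maps a weight name to its split/merge action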
@@ -1988,7 +1998,7 @@ def forward(
        if self.config.sequence_parallel:
            # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
            bs, seq_len, hidden_size = inputs_embeds.shape
-            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
            # inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size])
            # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
            inputs_embeds = ScatterOp.apply(inputs_embeds)
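For context, the sequence-parallel path in this hunk transposes activations to sequence-major and then shards the sequence dimension across the model-parallel group. A standalone sketch with plain paddle ops (illustrative; `mp_degree` and `mp_rank` are assumed values, and the real `ScatterOp` additionally handles the collective communication and autograd wiring):

import paddle

bs, seq_len, hidden_size = 2, 8, 16
mp_degree, mp_rank = 4, 0  # assumed model-parallel group size and local rank

inputs_embeds = paddle.randn([bs, seq_len, hidden_size])
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
# Keep only this rank's slice of the sequence dimension, as ScatterOp does:
local_embeds = paddle.split(inputs_embeds, mp_degree, axis=0)[mp_rank]  # [S/mp, B, H]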
@@ -2071,7 +2081,7 @@ def forward(

        if self.config.sequence_parallel:
            hidden_states = GatherOp.apply(hidden_states)
-            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
            # hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])

        inputs_embeds_cur_depth = paddle.concat(
@@ -2173,7 +2183,7 @@ def add_loss(main_loss, loss):
        seq_length = masked_lm_labels.shape[1]

        if self.config.sequence_parallel:
-            masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
+            masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
            masked_lm_labels = ScatterOp.apply(masked_lm_labels)

        loss = compute_loss(prediction_scores, masked_lm_labels)
@@ -2188,16 +2198,15 @@ def add_loss(main_loss, loss):
            masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]

            if self.config.sequence_parallel:
-                masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
+                masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
                masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)

            res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)
-
+
            if self.config.sequence_parallel:
                res_cur_depth = res_cur_depth * self.seq_para_scale
                dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)

-
            mtp_loss_res.append(res_cur_depth)
        loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res))  # fmt: skip

@@ -2245,9 +2254,9 @@ def __init__(self, config: DeepseekV2Config):
    def forward(self, hidden_states, tensor_parallel_output=None):

        # if self.config.sequence_parallel:
-        #     hidden_states = GatherOp.apply(hidden_states)
-        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
-        #     hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])
+        #     hidden_states = GatherOp.apply(hidden_states)
+        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        #     hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])

        if tensor_parallel_output is None:
            tensor_parallel_output = self.config.tensor_parallel_output