@@ -11,6 +11,40 @@
from .args import DeepSeekV3ModelArgs


+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+    ):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float = 0.02):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
# Reference: torchtitan/experiments/llama4/model/
class GroupedExperts(nn.Module):
    def __init__(
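The new class is the standard SwiGLU feed-forward, out = w2(silu(w1(x)) * w3(x)). As a quick illustration, a minimal sketch of exercising it standalone, assuming the class above is in scope (the sizes are made up, not from the model config):

import torch

ffn = FeedForward(dim=256, hidden_dim=1024)
ffn.init_weights(init_std=0.02)
x = torch.randn(8, 128, 256)  # (batch, seq, dim), illustrative sizes
y = ffn(x)                    # w2(silu(w1(x)) * w3(x))
assert y.shape == x.shape     # the MLP preserves the input shape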
@@ -212,11 +248,17 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
            GroupedExperts(
                dim=dim,
                hidden_dim=hidden_dim * model_args.n_shared_experts,
-                num_experts=1,
+                num_experts=1,  # must be 1 to make this equivalent to the MLP
                use_grouped_mm=self.use_grouped_mm,
            )
            if model_args.n_shared_experts > 0
            else None
+            # FeedForward(
+            #     dim=dim,
+            #     hidden_dim=hidden_dim * model_args.n_shared_experts,
+            # )
+            # if model_args.n_shared_experts > 0
+            # else None
        )

        # auxiliary-loss-free load balancing
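On the num_experts=1 comment: a single grouped expert computes the same SwiGLU as the commented-out FeedForward alternative. A sketch of that equivalence, assuming the grouped path stores weights with a leading expert dimension as in the llama4 reference (the weight layout here is an assumption, not GroupedExperts' verified internals):

import torch
import torch.nn.functional as F

dim, hidden_dim, tokens = 64, 128, 16
w1 = torch.randn(1, dim, hidden_dim)  # leading dim is num_experts == 1
w3 = torch.randn(1, dim, hidden_dim)
w2 = torch.randn(1, hidden_dim, dim)
x = torch.randn(1, tokens, dim)       # every token routed to the one expert

# grouped path: batched matmul over the expert dimension
grouped = torch.bmm(F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3), w2)

# plain-MLP path: the same math with the expert dimension dropped
mlp = (F.silu(x[0] @ w1[0]) * (x[0] @ w3[0])) @ w2[0]

torch.testing.assert_close(grouped[0], mlp)  # equal up to kernel numerics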
@@ -266,6 +308,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)

+        print(
+            "In MoE, top_scores shape: ",
+            top_scores.shape,
+            "token_indices: ",
+            token_indices.shape,
+            "num_local_tokens: ",
+            num_local_tokens_per_expert.shape,
+        )
+
        # will be used to update the expert bias for load balancing
        self.tokens_per_expert += num_local_tokens_per_expert
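tokens_per_expert only accumulates routing counts here; the bias itself is adjusted outside the forward pass. For context, a sketch of the auxiliary-loss-free update rule described in the DeepSeek-V3 paper (the function name and bias_update_speed hyperparameter are assumptions, not this repo's API):

import torch

@torch.no_grad()
def update_expert_bias(expert_bias, tokens_per_expert, bias_update_speed=0.001):
    # Overloaded experts (above the mean token count) get their bias pushed
    # down and underloaded ones get it pushed up; the router adds this bias
    # to its scores, so routing rebalances without an auxiliary loss term.
    load = tokens_per_expert.float()
    expert_bias += bias_update_speed * torch.sign(load.mean() - load)
    tokens_per_expert.zero_()  # reset the running count for the next step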
@@ -311,8 +362,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: this would incur a synchronization between device and host
        num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()

+        print("Num local tokens per expert: ", num_local_tokens_per_expert)
        # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(routed_input, num_local_tokens_per_expert)
+        routed_output = self.experts(
+            routed_input, num_local_tokens_per_expert
+        )  # torch.Size([16384(bsz), 256])
+        print("Routed output shape: ", routed_output.shape)
        routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
            x.dtype
        )
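The weighting step upcasts to float32 so the gate multiply does not lose precision in bf16, then casts back. A standalone sketch of the broadcast, with sizes mirroring the debug comment above (illustrative only):

import torch

routed_output = torch.randn(16384, 256, dtype=torch.bfloat16)  # (bs*slen*top_k, dim)
top_scores = torch.rand(16384)  # one gate score per routed token, float32
weighted = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
    routed_output.dtype
)  # unsqueeze(-1) broadcasts each score across dim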
@@ -321,10 +376,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
-            )
+            )  # torch.Size([16384, 256]) None
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

+        print(
+            "Out shape: ", out.shape, out.grad.shape if out.grad is not None else None
+        )
+
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bs, slen, dim)
        return out
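Finally, scatter_add folds the routed outputs back into the shared-expert (or zero) buffer: rows of routed_output whose token indices coincide, e.g. the top_k copies of one token, are summed into the same output row. A toy illustration with made-up sizes:

import torch

out = torch.zeros(4, 2)  # (bs*slen, dim) buffer, here 4 tokens of dim 2
routed_output = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
token_indices = torch.tensor([0, 0, 2]).unsqueeze(-1).expand(-1, 2)
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
# two routed copies of token 0 are summed: out[0] == [3., 3.]; out[2] == [3., 3.]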