@@ -156,7 +156,7 @@ class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
-        self.block_sparse_moe = MOEFeedForwardAOQuantizable(config)
+        self.block_sparse_moe = MoEFeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

@@ -225,41 +225,39 @@ def forward(
         y = self.wo(y)
         return y

+class MoEFeedForward(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
+        self.cond_ffn = ConditionalFeedForward(config)
+        self.dim = config.dim
+        self.num_activated_experts = config.num_activated_experts
+    def forward(self, x: Tensor) -> Tensor:
+        x = x.view(-1, self.dim)
+        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
+        # x: [T, D]
+        scores = self.gate(x)  # [T, E]
+        expert_weights = F.softmax(scores, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
+        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
+        expert_outs = self.cond_ffn(x, expert_indices)
+        return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)

-# class ConditionalFeedForward(nn.Module):
-#     def __init__(self, config):
-#         super().__init__()
-#         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-#         self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
-#         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-
-#     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-#         w1_weights = self.w1[expert_indices]  # [T, A, D, D]
-#         w3_weights = self.w3[expert_indices]  # [T, A, D, D]
-#         w2_weights = self.w2[expert_indices]  # [T, A, D, D]
-#         x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
-#         x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
-#         expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
-#         return expert_outs
-
-
-# class MOEFeedForward(nn.Module):
-#     def __init__(self, config) -> None:
-#         super().__init__()
-#         self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
-#         self.cond_ffn = ConditionalFeedForward(config)
-#         self.dim = config.dim
-#         self.num_activated_experts = config.num_activated_experts
-#     def forward(self, x: Tensor) -> Tensor:
-#         x = x.view(-1, self.dim)
-#         # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
-#         # x: [T, D]
-#         scores = self.gate(x)  # [T, E]
-#         expert_weights = F.softmax(scores, dim=-1)
-#         expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
-#         expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
-#         expert_outs = self.cond_ffn(x, expert_indices)
-#         return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
+class ConditionalFeedForward(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
+        self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+
+    def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
+        w1_weights = self.w1[expert_indices]  # [T, A, I, D]
+        w3_weights = self.w3[expert_indices]  # [T, A, I, D]
+        w2_weights = self.w2[expert_indices]  # [T, A, D, I]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
+        return expert_outs


 class RMSNorm(nn.Module):
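The routing in the new MoEFeedForward.forward is token-choice top-k: a linear gate scores every expert, softmax turns the scores into a distribution, topk keeps the A best experts per token, and the kept weights are renormalized to sum to 1. Below is a minimal standalone sketch of just that routing math with made-up toy sizes (none of these names come from this file):

```python
import torch
import torch.nn.functional as F

T, D, E, A = 4, 8, 6, 2  # tokens, hidden dim, experts, activated experts per token

x = torch.randn(T, D)
gate = torch.nn.Linear(D, E, bias=False)

scores = gate(x)                                                            # [T, E]
expert_weights = F.softmax(scores, dim=-1)                                  # [T, E]
expert_weights, expert_indices = torch.topk(expert_weights, A, dim=-1)      # [T, A], [T, A]
expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)  # renormalize over the A kept experts

print(expert_indices)              # which experts each token is routed to
print(expert_weights.sum(dim=-1))  # each row sums to 1 after renormalization
```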
@@ -301,6 +299,8 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     x_out2 = x_out2.flatten(3)
     return x_out2.type_as(x)

+# TODO: delete
+

 # T tokens
 # E experts
@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 # T'(e) tokens for expert e


-class MOEFeedForwardAOQuantizable(nn.Module):
+class MoEFeedForwardAOQuantizable(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
         self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
@@ -337,7 +337,7 @@ class ConditionalFeedForwardAOQuantizable(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.w1 = nn.Parameter(
+        self.w1 = nn.Parameter(  # num exp, expert_dim, hidden_dim
             torch.empty(config.num_experts, config.intermediate_size, config.dim)
         )  # E, I, D
         self.w2 = nn.Parameter(
@@ -347,6 +347,14 @@ def __init__(self, config):
             torch.empty(config.num_experts, config.intermediate_size, config.dim)
         )  # E, I, D
         self.num_experts = config.num_experts
+        self.perf_is_optimized = False
+
+    def optimize_perf(self):
+        self.w13 = torch.cat((self.w1, self.w3), dim=1)
+        self.w13 = torch.nn.Parameter(self.w13.transpose(-2, -1).contiguous().transpose(-2, -1))
+        self.w2 = torch.nn.Parameter(self.w2.transpose(-2, -1).contiguous().transpose(-2, -1))
+        del self.w1, self.w3
+        self.perf_is_optimized = True

     def forward(
         self,
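optimize_perf concatenates w1 and w3 along the intermediate dimension into a single [E, 2I, D] parameter, so one grouped matmul can produce both projections and a split separates them; the transpose/contiguous/transpose pattern only changes the memory layout (the transposed view becomes contiguous), not the values. A small standalone check of the fusion idea, using toy shapes rather than this module's config:

```python
import torch

E, I, D, T = 3, 5, 4, 7  # experts, intermediate size, hidden dim, tokens

w1 = torch.randn(E, I, D)
w3 = torch.randn(E, I, D)
w13 = torch.cat((w1, w3), dim=1)      # [E, 2I, D], as in optimize_perf

x = torch.randn(T, D)
e = 1                                 # look at a single expert

fused = x @ w13[e].transpose(-2, -1)  # [T, 2I]: one matmul for both projections
x1, x3 = fused.split(I, dim=1)        # [T, I] each

# the split halves match the two separate matmuls
assert torch.allclose(x1, x @ w1[e].transpose(-2, -1), atol=1e-6)
assert torch.allclose(x3, x @ w3[e].transpose(-2, -1), atol=1e-6)
```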
@@ -355,8 +363,60 @@ def forward(
         expert_weights: Tensor,  # T, A
         num_activated_experts: int,
     ) -> Tensor:
+
+
         num_tokens, dim = x.shape
-        num_token_activations = num_tokens * num_activated_experts
+        num_token_activations = expert_indices.numel()
+
+
+        ordered_token_activations = expert_indices.view(-1).argsort(stable=True)
+        ordered_token_indices = (
+            ordered_token_activations.div(num_activated_experts)
+            .floor()
+            .to(torch.int32)
+        )  # [T*A]
+
+        indices_for_histc = expert_indices.view(-1) if expert_indices.is_cuda else expert_indices.float().view(-1)  # histc doesn't work on cpu for integers
+        num_tokens_per_expert = torch.histc(
+            indices_for_histc,
+            bins=self.num_experts,
+            min=0,
+            max=self.num_experts,
+        )
+        offs = num_tokens_per_expert.cumsum(dim=0).to(torch.int32)
+        ordered_inputs = x[ordered_token_indices]
+
+        if self.perf_is_optimized:
+            x1, x3 = torch._grouped_mm(ordered_inputs, self.w13.transpose(-2, -1), offs).split(self.config.intermediate_size, dim=1)
+            y1 = F.silu(x1) * x3
+        else:
+            x1 = F.silu(torch._grouped_mm(ordered_inputs, self.w1.transpose(-2, -1), offs))
+            x3 = torch._grouped_mm(ordered_inputs, self.w3.transpose(-2, -1), offs)
+            y1 = x1 * x3
+        ordered_outs = torch._grouped_mm(y1, self.w2.transpose(-2, -1), offs)
+        # ordered_outs = torch._grouped_mm(y1, self.w2, offs)
+
+        ordered_token_activation_weights = expert_weights.view(-1, 1)[
+            ordered_token_activations
+        ].view(-1, 1)  # [T*A, 1]
+        weighted_ordered_outs = (
+            ordered_outs * ordered_token_activation_weights
+        )  # [T*A, D]
+
+        # sum weighted token-activation outputs together for each token
+        final_out = torch.zeros_like(x)  # [T, D]
+        final_out = final_out.scatter_add(
+            dim=0,
+            index=ordered_token_indices.unsqueeze(-1)
+            .expand(num_token_activations, dim)
+            .to(torch.int64),
+            src=weighted_ordered_outs,
+        )
+
+        return final_out
+
+
+
         if x.shape[0] == 1 and not isinstance(
             self.w1, FakeExtraDimTensor
         ):  # only 1 token (can be done without graph breaks when compiled)
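The grouped-matmul path above relies on some index bookkeeping: argsort of the flattened expert indices orders token activations expert by expert, integer division by num_activated_experts recovers the source token of each slot, and histc plus a cumulative sum gives the per-expert offsets passed as offs. A tiny standalone sketch of that bookkeeping with illustrative values (no grouped matmul involved):

```python
import torch

A, E = 2, 4                                     # activated experts per token, total experts
expert_indices = torch.tensor([[0, 2],
                               [2, 3],
                               [0, 0]])         # [T, A] with T = 3 tokens

flat = expert_indices.view(-1)                  # [T*A]
order = flat.argsort(stable=True)               # activation slots, grouped by expert
token_of_activation = order.div(A).floor().to(torch.int32)  # source token of each slot

counts = torch.histc(flat.float(), bins=E, min=0, max=E)    # tokens routed to each expert
offs = counts.cumsum(dim=0).to(torch.int32)                 # end offset of each expert's slab

print(flat[order])           # tensor([0, 0, 0, 2, 2, 3]): expert ids now contiguous
print(token_of_activation)   # tensor([0, 2, 2, 0, 1, 1], dtype=torch.int32)
print(offs)                  # tensor([3, 3, 5, 6], dtype=torch.int32)
```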
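Once the per-expert outputs are computed and scaled by their routing weights, scatter_add pushes each activation's row back into the row of the token it came from, undoing the grouping. A standalone sketch of that recombination step (toy sizes; token_of_activation takes the value produced in the previous sketch):

```python
import torch

T, D, A = 3, 4, 2
num_token_activations = T * A

ordered_outs = torch.randn(num_token_activations, D)    # weighted expert outputs, grouped by expert
token_of_activation = torch.tensor([0, 2, 2, 0, 1, 1])  # source token of each slot (int64)

final_out = torch.zeros(T, D)
final_out = final_out.scatter_add(
    dim=0,
    index=token_of_activation.unsqueeze(-1).expand(num_token_activations, D),
    src=ordered_outs,
)

# each token row is the sum of the activation rows that belong to it
expected = torch.stack([ordered_outs[token_of_activation == t].sum(dim=0) for t in range(T)])
assert torch.allclose(final_out, expected, atol=1e-6)
```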