from ...distributed import allgather
from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                           MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(

def masked_index_copy_group_quant_fp8(
    output: torch.Tensor,
+    output_s: torch.Tensor,
    input: torch.Tensor,
    start_offsets: torch.Tensor,
    row_indices: torch.Tensor,
@@ -108,14 +109,10 @@ def masked_index_copy_group_quant_fp8(
    col_size = output.shape[1]
    dim_size = output.shape[2]

-    # create padded output_s
    alignment = 4
    scale_dim = (dim_size + group_size - 1) // group_size
    padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
    padded_col_size = (col_size + alignment - 1) // alignment * alignment
-    output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
-                           dtype=torch.int32,
-                           device='cuda')

    # get block/grid/stage/warp
    num_groups = (dim_size + group_size - 1) // group_size
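Since the scale tensor is no longer allocated inside this helper, callers now provide `output_s` themselves. A minimal sketch of an equivalent allocation, following the padding rules above (the helper name is illustrative):

import torch

def make_output_s(row_size: int, dim_size: int, col_size: int,
                  group_size: int = 128, alignment: int = 4) -> torch.Tensor:
    # One scale per `group_size` elements along the last dim, padded to a
    # multiple of `alignment` and packed four per int32, as in the removed code.
    scale_dim = (dim_size + group_size - 1) // group_size
    padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
    padded_col_size = (col_size + alignment - 1) // alignment * alignment
    return torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
                       dtype=torch.int32,
                       device='cuda')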
@@ -247,17 +244,14 @@ def preprocess_after_permute(expert_first_token_offset_tensor,

@nvtx_range("[DG]")
def deepgemm_fp8_group_blockwise_gemm(
+    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    sfa: torch.Tensor,
    sfb: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
) -> torch.Tensor:
-    d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
-                    device=b.device,
-                    dtype=torch.bfloat16)
-
    # NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
    assert a.stride(-1) == 1
    assert b.stride(-1) == 1
@@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
        masked_m,
        expected_m,
        disable_ue8m0_cast=True)
-    return d
+    return
+
+
+def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
+    workspace = workspace[0:g * m * k]
+    workspace = workspace.as_strided(
+        size=(g, m, k),
+        stride=(m * k, k, 1),
+    )
+    return workspace
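`set_strides` lets callers carve a contiguous `[G, M, K]` view out of a preallocated flat buffer, which the grouped GEMM then fills through its new `d` argument. A small, self-contained sketch of that usage (sizes are illustrative):

import torch

num_experts, m_max, n = 4, 256, 1024  # illustrative sizes
workspace = torch.empty(num_experts * m_max * n,
                        dtype=torch.bfloat16,
                        device='cuda')

# View the front of the flat buffer as [G, M, K] without copying;
# deepgemm_fp8_group_blockwise_gemm(d=h, ...) then writes into it in place.
h = set_strides(workspace, num_experts, m_max, n)
assert h.shape == (num_experts, m_max, n)
assert h.data_ptr() == workspace.data_ptr()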


class DeepGemmFusedMoE(CutlassFusedMoE):
@@ -327,6 +330,18 @@ def __init__(
        apply_router_weight_on_input: bool = False,
        layer_idx: Optional[int] = None,
    ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8192, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # This avoids OOM for the 8k-ISL / 1k-batch case.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config._frozen = False
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens
+                model_config._frozen = True
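A quick standalone check of the constant in the comment above (assuming max_isl = 8192, max_batch_size = 1024, mtp = 0):

max_isl, max_batch_size, mtp = 8192, 1024, 0
max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
assert max_num_tokens == 9344
assert max_num_tokens * 2 == 18688  # default_moe_max_num_tokens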

        super().__init__(
            routing_method=routing_method,
@@ -342,6 +357,37 @@ def __init__(
            layer_idx=layer_idx,
        )

+    def get_workspace(self, m_max: int, group_size: int):
+        hidden_size = self.hidden_size
+        intermediate_size = self.intermediate_size
+        num_experts = self.expert_size_per_partition
+
+        # create workspace
+        fp8_dim = max(hidden_size, intermediate_size)
+        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
+                                  dtype=torch.float8_e4m3fn,
+                                  device='cuda')
+        workspace_1 = torch.empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+            dtype=torch.bfloat16,
+            device='cuda')
+
+        # create workspace for scaling factors
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        workspace_sf = torch.empty(
+            (num_experts * (scale_k_padded // 4) * m_padded),
+            dtype=torch.int32,
+            device='cuda')
+
+        workspace = {
+            "workspace_0": workspace_0,
+            "workspace_1": workspace_1,
+            "workspace_sf": workspace_sf,
+        }
+        return workspace
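The three buffers are sized to the largest view each must hold so the same memory can back both grouped GEMMs; `forward_chunk` re-views them per stage with `set_strides`. A plain-Python sketch of the sizing logic (`align`/`ceil_div` are stand-ins for the `fp8_utils` helpers):

def align(x: int, a: int) -> int:
    return (x + a - 1) // a * a

def ceil_div(x: int, a: int) -> int:
    return (x + a - 1) // a

def workspace_numels(num_experts: int, m_max: int, hidden_size: int,
                     intermediate_size: int, group_size: int = 128):
    fp8_dim = max(hidden_size, intermediate_size)       # fp8 inputs of gemm1 / gemm2
    bf16_dim = max(intermediate_size * 2, hidden_size)  # bf16 outputs of gemm1 / gemm2
    scale_k_padded = align(ceil_div(fp8_dim, group_size), 4)
    return (num_experts * m_max * fp8_dim,                          # workspace_0
            num_experts * m_max * bf16_dim,                         # workspace_1
            num_experts * (scale_k_padded // 4) * align(m_max, 4))  # workspace_sf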
+
    def _get_quant_method(self):
        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                exclude_kv_cache=True):
@@ -362,6 +408,7 @@ def forward_chunk(
        output_dtype: Optional[torch.dtype] = None,
        all_rank_num_tokens: Optional[List[int]] = None,
        use_dp_padding: Optional[bool] = None,
+        workspace: Optional[dict] = None,
    ) -> torch.Tensor:
        if isinstance(x, Fp4QuantizedTensor):
            assert output_dtype is not None
@@ -437,32 +484,72 @@ def forward_chunk(
        masked_m, token_to_expert_map = preprocess_after_permute(
            expert_first_token_offset_tensor, permuted_data_tensor)

-        m_max = (x.shape[0] + 127) // 128 * 128
        expected_m = (token_selected_experts.numel() +
                      self.expert_size_per_partition -
                      1) // self.expert_size_per_partition
-        act_input_fp8 = torch.empty(
-            (self.expert_size_per_partition, m_max, self.hidden_size),
-            dtype=torch.float8_e4m3fn,
-            device='cuda')
+
+        # padding and quantization
+        m_max = fp8_utils.align(x.shape[0], 128)
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.hidden_size)
+
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
        act_input_sf = masked_index_copy_group_quant_fp8(
            act_input_fp8,
+            act_input_sf,
            permuted_data_tensor,
            expert_first_token_offset_tensor,
            token_to_expert_map,
            group_size=128)

-        h1 = deepgemm_fp8_group_blockwise_gemm(
+        # grouped gemm 1
+        h1 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.intermediate_size * 2)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h1,
            a=act_input_fp8,
            b=self.w3_w1_weight,
            sfa=act_input_sf,
            sfb=self.quant_scales[0],
            masked_m=masked_m,
            expected_m=expected_m,
        )
-        act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
-            input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
-        h3 = deepgemm_fp8_group_blockwise_gemm(
+
+        # activation and quantization
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.intermediate_size)
+
+        scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
+        act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
+            output=act_input_fp8,
+            output_scale=act_input_sf,
+            input=h1,
+            quant_group_size=128,
+            masked_m=masked_m,
+            scale_ue8m0=True)
+
+        # grouped gemm 2
+        h3 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.hidden_size)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h3,
            a=act_input_fp8,
            b=self.w2_weight,
            sfa=act_input_sf,
@@ -471,6 +558,7 @@ def forward_chunk(
            expected_m=expected_m,
        )

+        # gather and finalize
        triton_masked_index_gather(permuted_data_tensor, h3,
                                   expert_first_token_offset_tensor,
                                   token_to_expert_map)
@@ -495,3 +583,137 @@ def forward_chunk(
        )

        return final_hidden_states
+
+    def forward(
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        do_finalize: bool = True,  # used by other MoE backends
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        use_dp_padding: Optional[bool] = None,
+    ) -> torch.Tensor:
+        assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
+        if self.use_dp and self.parallel_size > 1:
+            assert all_rank_num_tokens is not None
+            assert use_dp_padding is not None
+            num_rows = sum(all_rank_num_tokens)
+        else:
+            num_rows = x.shape[0]
+
+        # If num_rows exceeds moe_max_num_tokens * 2, split the input into multiple chunks,
+        # because chunked MoE uses two streams and preallocates two workspaces.
+        num_chunks = 1
+        if num_rows > self.moe_max_num_tokens * 2:
+            num_chunks = (num_rows + self.moe_max_num_tokens -
+                          1) // self.moe_max_num_tokens
+
+        if use_dp_padding:
+            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+                                          ] * len(all_rank_num_tokens)
+        else:
+            all_rank_num_tokens_padded = all_rank_num_tokens
+
+        if num_chunks == 1:
+            # create workspace
+            num_rows = x.shape[0]
+            if self.use_dp:
+                num_rows = sum(all_rank_num_tokens_padded)
+            m_max = fp8_utils.align(num_rows, 128)
+            workspace = self.get_workspace(m_max, 128)
+            outputs = self.forward_chunk(
+                x,
+                router_logits,
+                output_dtype,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding,
+                workspace=workspace)
+            outputs = self.reducescatter_or_allreduce(
+                outputs,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding)
+        else:
+            if self.use_dp:
+                all_rank_chunk_size_list = [
+                    self.split_chunk(val, num_chunks)
+                    for val in all_rank_num_tokens_padded
+                ]
+                all_rank_num_tokens_list = [[
+                    val[idx_chunk] for val in all_rank_chunk_size_list
+                ] for idx_chunk in range(num_chunks)]
+                chunk_size_list = all_rank_chunk_size_list[self.rank]
+            else:
+                all_rank_num_tokens_list = [None] * num_chunks
+                chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
+
+            # create workspace
+            chunk_size_0 = sum(all_rank_num_tokens_list[0]
+                               ) if self.use_dp else chunk_size_list[0]
+            chunk_size_1 = sum(all_rank_num_tokens_list[1]
+                               ) if self.use_dp else chunk_size_list[1]
+            workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
+                                             128)
+            workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
+                                             128)
+
+            x_list = x.split(chunk_size_list)
+            router_logits_list = router_logits.split(chunk_size_list)
+
+            self.event_dict[EventType.Main].record()
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.Main].wait()
+
+            def _forward_chunk(x_, router_logits_, idx, workspace):
+                return self.forward_chunk(
+                    x_,
+                    router_logits_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx]
+                    if self.use_dp else None,
+                    use_dp_padding=use_dp_padding,
+                    workspace=workspace)
+
+            def _reducescatter_or_allreduce(x_, idx):
+                return self.reducescatter_or_allreduce(
+                    x_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx],
+                    use_dp_padding=use_dp_padding)
+
+            outputs_list = []
+            # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
+            for idx_chunk, (x, router_logits) in enumerate(
+                    zip(x_list, router_logits_list)):
+
+                if idx_chunk % 2 == 0:
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                                 workspace_0)
+                    if idx_chunk > 0:
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+                else:
+                    outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                             workspace_1)
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+
+                outputs_list.append(outputs)
+
+            if num_chunks % 2 == 0:
+                outputs_list[-1] = _reducescatter_or_allreduce(
+                    outputs_list[-1], -1)
+            else:
+                with torch.cuda.stream(self.aux_stream):
+                    outputs_list[-1] = _reducescatter_or_allreduce(
+                        outputs_list[-1], -1)
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.MoeChunkingOverlap].record()
+            self.event_dict[EventType.MoeChunkingOverlap].wait()
+
+            outputs = torch.cat(outputs_list)
+
+        if self.use_dp and self.parallel_size > 1:
+            rank = self.mapping.tp_rank
+            outputs = outputs[:all_rank_num_tokens[rank]]
+        return outputs
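The chunked path alternates two workspaces and two CUDA streams so that chunk i's GEMMs overlap with chunk i-1's reduce-scatter/all-reduce. A CPU-only sketch of the resulting schedule (streams and events replaced by prints; purely illustrative):

def schedule(num_chunks: int):
    """Print the interleaving used above: even chunks run on the aux stream with
    workspace_0, odd chunks on the main stream with workspace_1, and each chunk's
    communication is deferred until the next chunk's compute has been issued."""
    for idx in range(num_chunks):
        ws, stream = ("workspace_0", "aux") if idx % 2 == 0 else ("workspace_1", "main")
        print(f"chunk {idx}: forward_chunk on {stream} stream using {ws}")
        if idx > 0:
            prev_stream = "main" if idx % 2 == 0 else "aux"
            print(f"chunk {idx - 1}: reduce-scatter/all-reduce on {prev_stream} stream")
    # The communication of the last chunk is issued after the loop.
    last_stream = "main" if num_chunks % 2 == 0 else "aux"
    print(f"chunk {num_chunks - 1}: reduce-scatter/all-reduce on {last_stream} stream")

schedule(num_chunks=3)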