Commit 1bbc0e3

lfr-0531 and yuxianq authored
[None][fix] Pre-allocate workspaces for DeepGEMM MoE to avoid frequent cudaFree/cudaMalloc (#6811)
Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: Yuxian Qiu <[email protected]>
Co-authored-by: Yuxian Qiu <[email protected]>
1 parent 47806f0 commit 1bbc0e3

File tree

6 files changed: +256 −47 lines changed
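
As context for the change: a minimal sketch, not taken from this commit, of the pattern it moves toward. Allocating intermediates whose sizes vary from call to call can push the CUDA caching allocator into repeated cudaFree/cudaMalloc, while carving views out of one buffer sized for the worst case avoids that. All names and sizes below are illustrative.

    import torch

    # Anti-pattern: a fresh, size-varying allocation on every call.
    def step_alloc_each_time(m: int, k: int) -> torch.Tensor:
        return torch.empty((m, k), dtype=torch.bfloat16, device="cuda")

    # Pattern adopted here: one upper-bound buffer, reinterpreted per call.
    workspace = torch.empty(1 << 26, dtype=torch.bfloat16, device="cuda")

    def step_reuse_workspace(m: int, k: int) -> torch.Tensor:
        # No new device memory is requested; this is just a view into `workspace`.
        return workspace[:m * k].view(m, k)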

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 5 deletions
@@ -110,13 +110,11 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
-        max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 241 additions & 19 deletions
@@ -11,7 +11,7 @@
 
 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(
 
 def masked_index_copy_group_quant_fp8(
         output: torch.Tensor,
+        output_s: torch.Tensor,
         input: torch.Tensor,
         start_offsets: torch.Tensor,
         row_indices: torch.Tensor,
@@ -108,14 +109,10 @@ def masked_index_copy_group_quant_fp8(
     col_size = output.shape[1]
     dim_size = output.shape[2]
 
-    # create padded output_s
     alignment = 4
     scale_dim = (dim_size + group_size - 1) // group_size
     padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
     padded_col_size = (col_size + alignment - 1) // alignment * alignment
-    output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
-                           dtype=torch.int32,
-                           device='cuda')
 
     # get block/grid/stage/warp
     num_groups = (dim_size + group_size - 1) // group_size
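
For orientation, a minimal sketch of the scale-buffer shape that callers now pre-allocate in place of the removed torch.zeros, using hypothetical sizes (32 experts, 9344 rows, hidden dim 7168, group size 128); each int32 element presumably packs four 8-bit group scales:

    num_experts, m, hidden, group_size = 32, 9344, 7168, 128  # hypothetical sizes
    alignment = 4

    scale_dim = (hidden + group_size - 1) // group_size                      # 56 groups per row
    padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment   # 56
    padded_col_size = (m + alignment - 1) // alignment * alignment           # 9344

    # Four 8-bit scales per int32 along the group dimension (hence the // 4).
    scale_shape = (num_experts, padded_dim_size // 4, padded_col_size)       # (32, 14, 9344)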
@@ -247,17 +244,14 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
 
 @nvtx_range("[DG]")
 def deepgemm_fp8_group_blockwise_gemm(
+    d: torch.Tensor,
     a: torch.Tensor,
     b: torch.Tensor,
     sfa: torch.Tensor,
     sfb: torch.Tensor,
     masked_m: torch.Tensor,
     expected_m: int,
 ) -> torch.Tensor:
-    d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
-                    device=b.device,
-                    dtype=torch.bfloat16)
-
     # NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
     assert a.stride(-1) == 1
     assert b.stride(-1) == 1
@@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
         masked_m,
         expected_m,
         disable_ue8m0_cast=True)
-    return d
+    return
+
+
+def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
+    workspace = workspace[0:g * m * k]
+    workspace = workspace.as_strided(
+        size=(g, m, k),
+        stride=(m * k, k, 1),
+    )
+    return workspace
 
 
 class DeepGemmFusedMoE(CutlassFusedMoE):
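
For readers unfamiliar with the trick, a standalone sketch (not from the commit) of what set_strides does: it reinterprets a prefix of a flat, pre-allocated buffer as a (g, m, k) tensor without allocating anything new.

    import torch

    def set_strides(workspace: torch.Tensor, g: int, m: int, k: int) -> torch.Tensor:
        # Take the first g*m*k elements of the flat buffer and view them as (g, m, k).
        workspace = workspace[0:g * m * k]
        return workspace.as_strided(size=(g, m, k), stride=(m * k, k, 1))

    flat = torch.empty(1 << 20, dtype=torch.bfloat16)   # allocated once up front
    view_a = set_strides(flat, 4, 128, 256)             # (4, 128, 256) view
    view_b = set_strides(flat, 4, 128, 512)             # different shape, same storage
    assert view_a.data_ptr() == view_b.data_ptr() == flat.data_ptr()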
@@ -327,6 +330,18 @@ def __init__(
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
     ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8196, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # It can avoid OOM for 8k/1k cases.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config._frozen = False
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens
+                model_config._frozen = True
 
         super().__init__(
             routing_method=routing_method,
@@ -342,6 +357,37 @@ def __init__(
             layer_idx=layer_idx,
         )
 
+    def get_workspace(self, m_max: int, group_size: int):
+        hidden_size = self.hidden_size
+        intermediate_size = self.intermediate_size
+        num_experts = self.expert_size_per_partition
+
+        # create workspace
+        fp8_dim = max(hidden_size, intermediate_size)
+        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
+                                  dtype=torch.float8_e4m3fn,
+                                  device='cuda')
+        workspace_1 = torch.empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+            dtype=torch.bfloat16,
+            device='cuda')
+
+        # create workspace for scaling factors
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        workspace_sf = torch.empty(
+            (num_experts * (scale_k_padded // 4) * m_padded),
+            dtype=torch.int32,
+            device='cuda')
+
+        workspace = {
+            "workspace_0": workspace_0,
+            "workspace_1": workspace_1,
+            "workspace_sf": workspace_sf,
+        }
+        return workspace
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
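
To get a feel for the sizes involved, a minimal sketch of the footprint get_workspace implies, with purely illustrative per-partition dimensions (none of these numbers come from the commit):

    # Hypothetical per-partition sizes, for illustration only.
    num_experts, m_max, group_size = 32, 18688, 128
    hidden_size, intermediate_size = 7168, 2048

    def ceil_div(a: int, b: int) -> int:
        return -(-a // b)

    def align(a: int, b: int) -> int:
        return ceil_div(a, b) * b

    fp8_dim = max(hidden_size, intermediate_size)
    ws0 = num_experts * m_max * fp8_dim                                      # float8_e4m3fn, 1 byte/elem
    ws1 = num_experts * m_max * max(intermediate_size * 2, hidden_size) * 2  # bfloat16, 2 bytes/elem
    ws_sf = num_experts * (align(ceil_div(fp8_dim, group_size), 4) // 4) * align(m_max, 4) * 4  # int32
    print(f"total workspace ~ {(ws0 + ws1 + ws_sf) / 2**30:.2f} GiB (illustrative)")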
@@ -362,6 +408,7 @@ def forward_chunk(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        workspace: Optional[dict] = None,
     ) -> torch.Tensor:
         if isinstance(x, Fp4QuantizedTensor):
             assert output_dtype is not None
@@ -437,32 +484,72 @@ def forward_chunk(
         masked_m, token_to_expert_map = preprocess_after_permute(
             expert_first_token_offset_tensor, permuted_data_tensor)
 
-        m_max = (x.shape[0] + 127) // 128 * 128
         expected_m = (token_selected_experts.numel() +
                       self.expert_size_per_partition -
                       1) // self.expert_size_per_partition
-        act_input_fp8 = torch.empty(
-            (self.expert_size_per_partition, m_max, self.hidden_size),
-            dtype=torch.float8_e4m3fn,
-            device='cuda')
+
+        # padding and quantization
+        m_max = fp8_utils.align(x.shape[0], 128)
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.hidden_size)
+
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
         act_input_sf = masked_index_copy_group_quant_fp8(
             act_input_fp8,
+            act_input_sf,
             permuted_data_tensor,
             expert_first_token_offset_tensor,
             token_to_expert_map,
             group_size=128)
 
-        h1 = deepgemm_fp8_group_blockwise_gemm(
+        # grouped gemm 1
+        h1 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.intermediate_size * 2)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h1,
             a=act_input_fp8,
             b=self.w3_w1_weight,
             sfa=act_input_sf,
             sfb=self.quant_scales[0],
             masked_m=masked_m,
             expected_m=expected_m,
         )
-        act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
-            input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
-        h3 = deepgemm_fp8_group_blockwise_gemm(
+
+        # activation and quantization
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.intermediate_size)
+
+        scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
+        act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
+            output=act_input_fp8,
+            output_scale=act_input_sf,
+            input=h1,
+            quant_group_size=128,
+            masked_m=masked_m,
+            scale_ue8m0=True)
+
+        # grouped gemm 2
+        h3 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.hidden_size)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h3,
             a=act_input_fp8,
             b=self.w2_weight,
             sfa=act_input_sf,
@@ -471,6 +558,7 @@ def forward_chunk(
             expected_m=expected_m,
         )
 
+        # gather and finalize
         triton_masked_index_gather(permuted_data_tensor, h3,
                                    expert_first_token_offset_tensor,
                                    token_to_expert_map)
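
To summarize the reuse above: each intermediate is a strided view over one of the three pre-allocated buffers rather than a fresh tensor. A schematic, with e = expert_size_per_partition, h = hidden_size, i = intermediate_size:

    # Schematic only; shapes follow the set_strides calls in forward_chunk.
    workspace_reuse = {
        "workspace_0": ["fp8 quantized input      (e, m_max, h)",
                        "fp8 silu_and_mul output  (e, m_max, i)"],
        "workspace_1": ["bf16 gemm1 output h1     (e, m_max, 2*i)",
                        "bf16 gemm2 output h3     (e, m_max, h)"],
        "workspace_sf": ["packed int32 block scales for whichever fp8 view is live"],
    }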
@@ -495,3 +583,137 @@ def forward_chunk(
         )
 
         return final_hidden_states
+
+    def forward(
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        do_finalize: bool = True,  # used by other MoE backends
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        use_dp_padding: Optional[bool] = None,
+    ) -> torch.Tensor:
+        assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
+        if self.use_dp and self.parallel_size > 1:
+            assert all_rank_num_tokens is not None
+            assert use_dp_padding is not None
+            num_rows = sum(all_rank_num_tokens)
+        else:
+            num_rows = x.shape[0]
+
+        # In case of num_rows is larger than max_chunk_size * 2, we need to split the input into multiple chunks.
+        # Because we will use two streams in chunked moe and preallocate two workspaces.
+        num_chunks = 1
+        if num_rows > self.moe_max_num_tokens * 2:
+            num_chunks = (num_rows + self.moe_max_num_tokens -
+                          1) // self.moe_max_num_tokens
+
+        if use_dp_padding:
+            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+                                          ] * len(all_rank_num_tokens)
+        else:
+            all_rank_num_tokens_padded = all_rank_num_tokens
+
+        if num_chunks == 1:
+            # create workspace
+            num_rows = x.shape[0]
+            if self.use_dp:
+                num_rows = sum(all_rank_num_tokens_padded)
+            m_max = fp8_utils.align(num_rows, 128)
+            workspace = self.get_workspace(m_max, 128)
+            outputs = self.forward_chunk(
+                x,
+                router_logits,
+                output_dtype,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding,
+                workspace=workspace)
+            outputs = self.reducescatter_or_allreduce(
+                outputs,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding)
+        else:
+            if self.use_dp:
+                all_rank_chunk_size_list = [
+                    self.split_chunk(val, num_chunks)
+                    for val in all_rank_num_tokens_padded
+                ]
+                all_rank_num_tokens_list = [[
+                    val[idx_chunk] for val in all_rank_chunk_size_list
+                ] for idx_chunk in range(num_chunks)]
+                chunk_size_list = all_rank_chunk_size_list[self.rank]
+            else:
+                all_rank_num_tokens_list = [None] * num_chunks
+                chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
+
+            # create workspace
+            chunk_size_0 = sum(all_rank_num_tokens_list[0]
+                               ) if self.use_dp else chunk_size_list[0]
+            chunk_size_1 = sum(all_rank_num_tokens_list[1]
+                               ) if self.use_dp else chunk_size_list[1]
+            workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
+                                             128)
+            workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
+                                             128)
+
+            x_list = x.split(chunk_size_list)
+            router_logits_list = router_logits.split(chunk_size_list)
+
+            self.event_dict[EventType.Main].record()
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.Main].wait()
+
+            def _forward_chunk(x_, router_logits_, idx, workspace):
+                return self.forward_chunk(
+                    x_,
+                    router_logits_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx]
+                    if self.use_dp else None,
+                    use_dp_padding=use_dp_padding,
+                    workspace=workspace)
+
+            def _reducescatter_or_allreduce(x_, idx):
+                return self.reducescatter_or_allreduce(
+                    x_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx],
+                    use_dp_padding=use_dp_padding)
+
+            outputs_list = []
+            # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
+            for idx_chunk, (x, router_logits) in enumerate(
+                    zip(x_list, router_logits_list)):
+
+                if idx_chunk % 2 == 0:
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                                 workspace_0)
+                    if idx_chunk > 0:
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+                else:
+                    outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                             workspace_1)
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+
+                outputs_list.append(outputs)
+
+            if num_chunks % 2 == 0:
+                outputs_list[-1] = _reducescatter_or_allreduce(
+                    outputs_list[-1], -1)
+            else:
+                with torch.cuda.stream(self.aux_stream):
+                    outputs_list[-1] = _reducescatter_or_allreduce(
+                        outputs_list[-1], -1)
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.MoeChunkingOverlap].record()
+            self.event_dict[EventType.MoeChunkingOverlap].wait()
+
+            outputs = torch.cat(outputs_list)
+
+        if self.use_dp and self.parallel_size > 1:
+            rank = self.mapping.tp_rank
+            outputs = outputs[:all_rank_num_tokens[rank]]
+        return outputs
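
A small worked example of when the chunked path (and the two ping-pong workspaces) engages; the numbers are hypothetical:

    moe_max_num_tokens = 18688   # the default cap applied in __init__ above
    num_rows = 50000             # hypothetical total tokens across DP ranks

    num_chunks = 1
    if num_rows > moe_max_num_tokens * 2:                 # 50000 > 37376, so chunk
        num_chunks = (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens
    print(num_chunks)  # 3 -> even-indexed chunks use workspace_0, odd-indexed use workspace_1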

tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py

Lines changed: 2 additions & 6 deletions
@@ -81,13 +81,9 @@ def __init__(
             self.num_experts)
         self.expert_size_per_partition = self.expert_end - self.expert_start
 
-        max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
-                                   if model_config.moe_max_num_tokens
-                                   is not None else max_num_tokens)
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
 
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
