Skip to content

Commit d6150c5

Browse files
committed
Only switch to chunked DeepGEMM MoE when num_rows is greater than self.moe_max_num_tokens * 2.
Signed-off-by: Fanrong Li <[email protected]>
1 parent 306b5c6 commit d6150c5

File tree

5 files changed

+12
-6
lines changed

5 files changed

+12
-6
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def __init__(
110110
assert len(
111111
self.initial_local_expert_ids) == self.expert_size_per_partition
112112

113-
max_num_tokens = model_config.max_num_tokens
114113
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
115114
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
116115
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ def __init__(
339339
# It can avoid OOM for 8k/1k cases.
340340
default_moe_max_num_tokens = 18688
341341
if moe_max_num_tokens > default_moe_max_num_tokens:
342+
model_config._frozen = False
342343
model_config.moe_max_num_tokens = default_moe_max_num_tokens
344+
model_config._frozen = True
343345

344346
super().__init__(
345347
routing_method=routing_method,
@@ -600,9 +602,12 @@ def forward(
600602
else:
601603
num_rows = x.shape[0]
602604

603-
# in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
604-
num_chunks = (num_rows + self.moe_max_num_tokens -
605-
1) // self.moe_max_num_tokens
605+
# If num_rows is larger than max_chunk_size * 2, we need to split the input into multiple chunks.
606+
# This is because chunked MoE uses two streams and preallocates two workspaces.
607+
num_chunks = 1
608+
if num_rows > self.moe_max_num_tokens * 2:
609+
num_chunks = (num_rows + self.moe_max_num_tokens -
610+
1) // self.moe_max_num_tokens
606611

607612
if use_dp_padding:
608613
all_rank_num_tokens_padded = [all_rank_max_num_tokens

tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def __init__(
8181
self.num_experts)
8282
self.expert_size_per_partition = self.expert_end - self.expert_start
8383

84-
max_num_tokens = model_config.max_num_tokens
8584
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
8685
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
8786
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def __init__(
150150
assert len(
151151
self.initial_local_expert_ids) == self.expert_size_per_partition
152152

153-
max_num_tokens = model_config.max_num_tokens
154153
# The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
155154
moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
156155
self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

tensorrt_llm/mapping.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,10 @@ def node_rank(self):
372372
def local_rank(self):
373373
return self.rank % self.gpus_per_node
374374

375+
@property
376+
def dp_size(self):
377+
return self.tp_size if self.enable_attention_dp else 1
378+
375379
def has_cp(self):
376380
return self.cp_size > 1
377381

0 commit comments

Comments
 (0)