
Commit 2bd9e77

Update default_moe_max_num_tokens.

Signed-off-by: Yuxian Qiu <[email protected]>
1 parent dc924cd

File tree: 4 files changed, +18 −8 lines


tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -111,10 +111,10 @@ def __init__(
             self.initial_local_expert_ids) == self.expert_size_per_partition

         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or model_config.max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
```
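Before this commit, the fallback default here was `model_config.max_num_tokens`, which is not scaled by DP size, while the chunking check compared against the scaled value; with attention DP enabled and `moe_max_num_tokens` unset, the check was therefore always true and the aux-stream path was set up even when no chunking was needed. A minimal sketch of the before/after fallback, with made-up config values (not TensorRT-LLM code):

```python
max_num_tokens = 4096  # model_config.max_num_tokens
dp_size = 4            # model_config.mapping.dp_size, attention DP enabled
configured = None      # model_config.moe_max_num_tokens left unset

# With attention DP, up to max_num_tokens * dp_size tokens can reach the MoE layer.
scaled = max_num_tokens * dp_size      # 16384

before = configured or max_num_tokens  # old fallback: 4096
after = configured or scaled           # new fallback: 16384

print(before < scaled)  # True  -> chunking/aux-stream path taken even by default
print(after < scaled)   # False -> chunking only when a smaller limit is configured
```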

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -327,6 +327,16 @@ def __init__(
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
     ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8192, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # It can avoid OOM for 8k/1k cases.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens

         super().__init__(
             routing_method=routing_method,
```
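The arithmetic in the new comment can be checked directly. A small sketch reproducing the 18688 default, using the values stated in the comment (8k input sequence length, batch size 1024, no MTP draft tokens):

```python
max_isl, max_batch_size, mtp = 8192, 1024, 0

# Round (mtp + 1) * max_batch_size + max_isl + 128 up to a multiple of 64.
max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
assert max_num_tokens == 9344

default_moe_max_num_tokens = max_num_tokens * 2
assert default_moe_max_num_tokens == 18688
```

Capping the default presumably bounds the intermediate workspace sized by `moe_max_num_tokens`, which is what avoids OOM in the 8k/1k case; when the scaled token count exceeds the cap, the MoE falls back to the chunked path guarded in the other backends touched by this commit.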

tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -82,8 +82,8 @@ def __init__(
         self.expert_size_per_partition = self.expert_end - self.expert_start

         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens

         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
```

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -151,10 +151,10 @@ def __init__(
             self.initial_local_expert_ids) == self.expert_size_per_partition

         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or model_config.max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
```
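As in the CUTLASS backend, the aux stream is only created when `self.moe_max_num_tokens` ends up below the scaled token count, i.e. when chunking can actually occur. A hypothetical illustration (made-up values, not TensorRT-LLM code) of why the capped limit implies chunked execution, assuming a ceil-divided split:

```python
import math

moe_max_num_tokens = 18688  # the DeepGEMM default cap from above
num_tokens = 9344 * 8       # e.g. max_num_tokens * dp_size with dp_size = 8

num_chunks = math.ceil(num_tokens / moe_max_num_tokens)
chunk_sizes = [min(moe_max_num_tokens, num_tokens - i * moe_max_num_tokens)
               for i in range(num_chunks)]

print(num_chunks, chunk_sizes)  # 4 [18688, 18688, 18688, 18688]
```

The chunk loop itself lives elsewhere in these modules; the sketch only shows why the `MoeChunkingOverlap` aux stream is worth creating in this configuration.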
