
Commit 97f7e12 (parent: dc75779)

[fix] Fix perf regression caused by MoE autotuner when using DeepEPLowLatency (#6288)

Signed-off-by: Jinyang Yuan <[email protected]>
3 files changed: +38 −10 lines

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 21 additions & 8 deletions
@@ -39,7 +39,6 @@ def __init__(
         ep_rank: int,
         cluster_size: int,
         cluster_rank: int,
-        enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
         use_mxfp8_act_scaling: bool,
@@ -55,7 +54,8 @@ def __init__(
         self.ep_rank = ep_rank
         self.cluster_size = cluster_size
         self.cluster_rank = cluster_rank
-        self.enable_alltoall = enable_alltoall
+        # The best tactic is estimated as if alltoall is disabled
+        self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
@@ -141,24 +141,37 @@ def fused_moe(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    tuner_num_tokens: Optional[int] = None,
+    tuner_top_k: Optional[int] = None,
 ) -> List[torch.Tensor]:
 
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
 
+    # Only the non-alltoall case is considered for profiling in the warmup phase.
+    # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
+    if enable_alltoall:
+        assert tuner_num_tokens is not None
+        assert tuner_top_k is not None
+        tuner_input = input[:tuner_num_tokens]
+    else:
+        assert tuner_num_tokens is None
+        assert tuner_top_k is None
+        tuner_input = input
+        tuner_top_k = token_selected_experts.size(1)
+
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
         weight_dtype=fc1_expert_weights.dtype,
         output_dtype=output_dtype,
-        top_k=token_selected_experts.size(1),
+        top_k=tuner_top_k,
         tp_size=tp_size,
         tp_rank=tp_rank,
         ep_size=ep_size,
         ep_rank=ep_rank,
         cluster_size=cluster_size,
         cluster_rank=cluster_rank,
-        enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
@@ -170,8 +183,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
         ],
         gemm_idx=1,
     )
@@ -181,8 +194,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
        ],
         gemm_idx=2,
     )
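The key change above is that, when alltoall is enabled, the autotuner no longer profiles on the dispatched (and possibly padded) activations: the input is sliced down to the pre-dispatch token count so the tactics chosen during warmup match the problem size actually profiled. Below is a minimal standalone sketch of that slicing logic; the helper name and the example shapes are illustrative assumptions, not part of the TensorRT-LLM API.

import torch
from typing import Optional


# Hypothetical helper mirroring the tuner-input selection added above; names and
# shapes are made up for illustration, this is not the TensorRT-LLM implementation.
def select_tuner_input(x: torch.Tensor,
                       enable_alltoall: bool,
                       tuner_num_tokens: Optional[int] = None) -> torch.Tensor:
    if enable_alltoall:
        # After the alltoall dispatch, x can be padded/replicated per rank, so
        # profiling on it would tune for the wrong number of tokens.
        assert tuner_num_tokens is not None
        return x[:tuner_num_tokens]
    # Without alltoall, the runtime input already matches the profiled shape.
    return x


post_dispatch = torch.randn(512, 4096)   # padded tokens seen after alltoall
pre_dispatch_tokens = 96                 # total tokens across ranks before dispatch
tuner_input = select_tuner_input(post_dispatch, True, pre_dispatch_tokens)
print(tuner_input.shape)                 # torch.Size([96, 4096])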

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 15 additions & 0 deletions
@@ -437,6 +437,19 @@ def forward_chunk(
 
         # If alltoall is disabled, we need also disable use_postquant_alltoall
         use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+
+        # Prepare additional information for profiling in case padding is applied when using alltoall.
+        # Only the non-alltoall case is considered for profiling in the warmup phase.
+        # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
+        if use_all_to_all:
+            if all_rank_num_tokens is not None:
+                tuner_num_tokens = sum(all_rank_num_tokens)
+            else:
+                tuner_num_tokens = x.shape[0] * self.mapping.tp_size
+            tuner_top_k = token_selected_slots.shape[1]
+        else:
+            tuner_num_tokens = None
+            tuner_top_k = None
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
@@ -652,6 +665,8 @@ def forward_chunk(
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,
             tune_max_num_tokens=self.tune_max_num_tokens,
+            tuner_num_tokens=tuner_num_tokens,
+            tuner_top_k=tuner_top_k,
         )
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
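For context, the branch added above derives the pre-dispatch token count either from the per-rank token list or, as a fallback, from the local count scaled by the TP size. A small sketch with made-up numbers, assuming four ranks with 24 tokens each; the function name is hypothetical and only mirrors the logic in the diff.

from typing import List, Optional


# Hypothetical helper mirroring the tuner_num_tokens computation in the diff;
# the name and example values are illustrative only.
def tuner_token_count(all_rank_num_tokens: Optional[List[int]],
                      local_num_tokens: int,
                      tp_size: int) -> int:
    if all_rank_num_tokens is not None:
        # Total tokens across all ranks before the alltoall dispatch.
        return sum(all_rank_num_tokens)
    # Fallback: assume every rank holds the same number of tokens as this one.
    return local_num_tokens * tp_size


print(tuner_token_count([24, 24, 24, 24], 24, 4))  # 96
print(tuner_token_count(None, 24, 4))              # 96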

tensorrt_llm/_torch/utils.py

Lines changed: 2 additions & 2 deletions
@@ -229,7 +229,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
         num_token_buckets.append(m)
         m //= 2
 
-    return tuple(num_token_buckets)
+    return tuple(num_token_buckets[::-1])
 
 
 def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
@@ -239,7 +239,7 @@ def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
     while m >= 1:
         num_token_buckets.append(m)
         m //= 2
-    return tuple(num_token_buckets)
+    return tuple(num_token_buckets[::-1])
 
 
 def fp4_scale_infer_shape(input_shapes: List[List[int]]):
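The only behavioral change in this file is the [::-1]: both helpers still generate power-of-two bucket sizes by halving down from max_num_tokens, but now return them in ascending order. A minimal reproduction showing the effect; the function body outside the hunks is assumed from the diff context.

# Reconstructed from the diff context; the initialization lines are assumed.
def get_power_of_2_num_tokens_buckets(max_num_tokens):
    num_token_buckets = []
    m = max_num_tokens
    while m >= 1:
        num_token_buckets.append(m)
        m //= 2
    # Previously returned largest-first; the fix reverses to ascending order.
    return tuple(num_token_buckets[::-1])


print(get_power_of_2_num_tokens_buckets(8))  # (1, 2, 4, 8) instead of (8, 4, 2, 1)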
