@@ -39,7 +39,6 @@ def __init__(
         ep_rank: int,
         cluster_size: int,
         cluster_rank: int,
-        enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
         use_mxfp8_act_scaling: bool,
@@ -55,7 +54,8 @@ def __init__(
         self.ep_rank = ep_rank
         self.cluster_size = cluster_size
         self.cluster_rank = cluster_rank
-        self.enable_alltoall = enable_alltoall
+        # The best tactic is estimated as if alltoall is disabled
+        self.enable_alltoall = False
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
         self.use_mxfp8_act_scaling = use_mxfp8_act_scaling
@@ -141,24 +141,37 @@ def fused_moe(
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
+    tuner_num_tokens: Optional[int] = None,
+    tuner_top_k: Optional[int] = None,
 ) -> List[torch.Tensor]:
 
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
 
+    # Only the non-alltoall case is considered for profiling in the warmup phase.
+    # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
+    if enable_alltoall:
+        assert tuner_num_tokens is not None
+        assert tuner_top_k is not None
+        tuner_input = input[:tuner_num_tokens]
+    else:
+        assert tuner_num_tokens is None
+        assert tuner_top_k is None
+        tuner_input = input
+        tuner_top_k = token_selected_experts.size(1)
+
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
         weight_dtype=fc1_expert_weights.dtype,
         output_dtype=output_dtype,
-        top_k=token_selected_experts.size(1),
+        top_k=tuner_top_k,
         tp_size=tp_size,
         tp_rank=tp_rank,
         ep_size=ep_size,
         ep_rank=ep_rank,
         cluster_size=cluster_size,
         cluster_rank=cluster_rank,
-        enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
         use_mxfp8_act_scaling=use_mxfp8_act_scaling,
@@ -170,8 +183,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
         ],
         gemm_idx=1,
     )
@@ -181,8 +194,8 @@ def fused_moe(
         [moe_runner],
         MoERunner.tuning_config,
         [
-            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
-            fc2_expert_biases
+            tuner_input, fc1_expert_weights, fc1_expert_biases,
+            fc2_expert_weights, fc2_expert_biases
         ],
         gemm_idx=2,
     )
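The two new optional arguments follow a both-or-neither contract: when enable_alltoall is set, the caller must supply tuner_num_tokens and tuner_top_k so the warmup-profiled tactics match what the non-alltoall path would select; when it is not set, both must be left as None and the values are derived from the real inputs. A minimal standalone sketch of that rule follows; the helper name select_tuner_inputs and the toy shapes in the example are assumptions for illustration, not part of the patch.

# Sketch of the selection rule applied to the tuner inputs in the patch above.
# The helper and the example shapes are illustrative assumptions; only the
# both-or-neither rule itself comes from the diff.
from typing import Optional, Tuple

import torch


def select_tuner_inputs(
    input: torch.Tensor,
    token_selected_experts: torch.Tensor,
    enable_alltoall: bool,
    tuner_num_tokens: Optional[int] = None,
    tuner_top_k: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
    if enable_alltoall:
        # Alltoall path: both hints are mandatory, and the tuner profiles a
        # slice whose token count matches what the non-alltoall path would see.
        assert tuner_num_tokens is not None and tuner_top_k is not None
        tuner_input = input[:tuner_num_tokens]
    else:
        # Non-alltoall path: the hints must stay None; the tuner sees the real
        # input and top_k is read from the routing tensor.
        assert tuner_num_tokens is None and tuner_top_k is None
        tuner_input = input
        tuner_top_k = token_selected_experts.size(1)
    return tuner_input, tuner_top_k


# Example with made-up shapes: 32 post-dispatch tokens profiled as if they
# were 8 pre-dispatch tokens routed to 2 experts each.
x = torch.randn(32, 4096)
routing = torch.zeros(32, 2, dtype=torch.int64)
tuner_input, top_k = select_tuner_inputs(
    x, routing, enable_alltoall=True, tuner_num_tokens=8, tuner_top_k=2)
assert tuner_input.shape[0] == 8 and top_k == 2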